diff --git a/src/main.py b/src/main.py index 078f6646..eabbe0ca 100644 --- a/src/main.py +++ b/src/main.py @@ -171,6 +171,28 @@ def create_workflow(selected_analysts=None): # Parse tickers from comma-separated string tickers = [ticker.strip() for ticker in args.tickers.split(",")] + # Validate tickers before proceeding + from tools.api import validate_tickers + is_valid, invalid_tickers = validate_tickers(tickers) + if not is_valid: + valid_tickers = [ticker for ticker in tickers if ticker not in invalid_tickers] + if not valid_tickers: + print(f"{Fore.RED}Error: All provided tickers are invalid: {', '.join(invalid_tickers)}{Style.RESET_ALL}") + sys.exit(1) + + print(f"{Fore.YELLOW}Warning: The following tickers are invalid: {', '.join(invalid_tickers)}{Style.RESET_ALL}") + proceed = questionary.confirm( + f"Do you want to proceed with only the valid tickers: {', '.join(valid_tickers)}?", + default=True + ).ask() + + if not proceed: + print("\nExiting...") + sys.exit(0) + + print(f"\nProceeding with tickers: {', '.join(valid_tickers)}\n") + tickers = valid_tickers + # Select analysts selected_analysts = None choices = questionary.checkbox( @@ -193,7 +215,7 @@ def create_workflow(selected_analysts=None): sys.exit(0) else: selected_analysts = choices - print(f"\nSelected analysts: {', '.join(Fore.GREEN + choice.title().replace('_', ' ') + Style.RESET_ALL for choice in choices)}\n") + print(f"\nSelected analysts: {', '.join(choice.title().replace('_', ' ') for choice in choices)}\n") # Select LLM model model_choice = questionary.select( diff --git a/src/tools/api.py b/src/tools/api.py index 2b2c9c01..9cdea719 100644 --- a/src/tools/api.py +++ b/src/tools/api.py @@ -280,3 +280,21 @@ def prices_to_df(prices: list[Price]) -> pd.DataFrame: def get_price_data(ticker: str, start_date: str, end_date: str) -> pd.DataFrame: prices = get_prices(ticker, start_date, end_date) return prices_to_df(prices) + + +def validate_tickers(tickers: list[str]) -> tuple[bool, list[str]]: + """ + Validate a list of tickers against the available tickers list. + Returns a tuple of (is_valid, invalid_tickers). + """ + try: + response = requests.get("https://virattt.github.io/datasets/financials/available_tickers.json") + if response.status_code != 200: + raise Exception(f"Error fetching valid tickers: {response.status_code} - {response.text}") + + valid_tickers = {ticker["symbol"] for ticker in response.json()["tickers"]} + invalid_tickers = [ticker for ticker in tickers if ticker not in valid_tickers] + + return len(invalid_tickers) == 0, invalid_tickers + except Exception as e: + raise Exception(f"Error validating tickers: {str(e)}")