diff --git a/aishell/cli.py b/aishell/cli.py index 3441437..6f07bb7 100644 --- a/aishell/cli.py +++ b/aishell/cli.py @@ -64,14 +64,18 @@ def ask(question: str, language_model: LanguageModel = LanguageModel.REVERSE_ENG config_manager = AiShellConfigManager(load_config=True) else: config_manager = config_aishell() + config_manager.config_model.language_model = language_model + configured_language_model = config_manager.config_model.language_model query_client: QueryClient - if language_model == LanguageModel.REVERSE_ENGINEERED_CHATGPT: + if configured_language_model == LanguageModel.REVERSE_ENGINEERED_CHATGPT: query_client = ReverseEngineeredChatGPTClient(config=config_manager.config_model.chatgpt_config) - elif language_model == LanguageModel.GPT3: + elif configured_language_model == LanguageModel.GPT3: query_client = GPT3Client() - elif language_model == LanguageModel.OFFICIAL_CHATGPT: + elif configured_language_model == LanguageModel.OFFICIAL_CHATGPT: query_client = OfficialChatGPTClient() + else: + raise NotImplementedError(f'Language model {configured_language_model} is not implemented yet.') query_client.query(question)