Skip to content

Commit bad5858

Browse files
authored
Merge pull request #29 from code-yeongyu/feature/improve-readibility
2 parents 6cad61e + b240a49 commit bad5858

7 files changed

+74
-37
lines changed

aishell/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .language_model import LanguageModel as LanguageModel
22
from .open_ai_response_model import OpenAIResponseModel as OpenAIResponseModel
3+
from .revchatgpt_chatbot_config_model import RevChatGPTChatbotConfigModel as RevChatGPTChatbotConfigModel

aishell/models/open_ai_response_model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import Optional
22

33
from pydantic import BaseModel
44

@@ -16,9 +16,12 @@ class Usage(BaseModel):
1616
prompt_tokens: int
1717
total_tokens: int
1818

19-
choices: Optional[List[Choice]]
19+
choices: Optional[list[Choice]]
2020
created: int
2121
id: str
2222
model: str
2323
object: str
2424
usage: Usage
25+
26+
class Config:
27+
frozen = True
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, root_validator
4+
5+
6+
class RevChatGPTChatbotConfigModel(BaseModel):
7+
email: Optional[str] = None
8+
password: Optional[str] = None
9+
session_token: Optional[str] = None
10+
access_token: Optional[str] = None
11+
paid: bool = False
12+
13+
@root_validator
14+
def check_at_least_one_account_info(cls, values: dict[str, Optional[str]]):
15+
IS_ACCOUNT_LOGIN = values.get('email') and values.get('password')
16+
IS_TOKEN_AUTH = values.get('session_token') or values.get('access_token')
17+
if not IS_ACCOUNT_LOGIN and not IS_TOKEN_AUTH:
18+
raise ValueError('No information for authentication provided.')
19+
20+
return values
21+
22+
class Config:
23+
frozen = True

aishell/query_clients/gpt3_client.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import cast
2+
from typing import Final, cast
33

44
import openai
55

@@ -11,16 +11,9 @@
1111

1212
class GPT3Client(QueryClient):
1313

14-
def _construct_prompt(self, text: str) -> str:
15-
return f'''User: You are now a translater from human language to {os.uname()[0]} shell command.
16-
No explanation required, respond with only the raw shell command.
17-
What should I type to shell for: {text}, in one line.
18-
19-
You: '''
20-
2114
def query(self, prompt: str) -> str:
2215
prompt = self._construct_prompt(prompt)
23-
completion: OpenAIResponseModel = cast( # type: ignore [no-any-unimported]
16+
completion: Final[OpenAIResponseModel] = cast(
2417
OpenAIResponseModel,
2518
openai.Completion.create(
2619
engine='text-davinci-003',
@@ -32,5 +25,12 @@ def query(self, prompt: str) -> str:
3225
)
3326
if not completion.choices or len(completion.choices) == 0 or not completion.choices[0].text:
3427
raise RuntimeError('No response from OpenAI')
35-
response_text: str = completion.choices[0].text
28+
response_text: Final[str] = completion.choices[0].text
3629
return make_executable_command(response_text)
30+
31+
def _construct_prompt(self, text: str) -> str:
32+
return f'''User: You are now a translater from human language to {os.uname()[0]} shell command.
33+
No explanation required, respond with only the raw shell command.
34+
What should I type to shell for: {text}, in one line.
35+
36+
You: '''
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Final, Optional
2+
from typing import Optional
33

44
from revChatGPT.V3 import Chatbot
55

@@ -10,28 +10,27 @@
1010

1111

1212
class OfficialChatGPTClient(QueryClient):
13-
openai_api_key: str
1413

1514
def __init__(
1615
self,
1716
openai_api_key: Optional[str] = None,
1817
):
1918
super().__init__()
20-
OPENAI_API_KEY: Final[Optional[str]] = os.environ.get('OPENAI_API_KEY', openai_api_key)
19+
OPENAI_API_KEY: Optional[str] = os.environ.get('OPENAI_API_KEY', openai_api_key)
2120
if OPENAI_API_KEY is None:
2221
raise UnauthorizedAccessError('OPENAI_API_KEY should not be none')
2322

24-
self.openai_api_key = OPENAI_API_KEY
25-
26-
def _construct_prompt(self, text: str) -> str:
27-
return f'''You are now a translater from human language to {os.uname()[0]} shell command.
28-
No explanation required, respond with only the raw shell command.
29-
What should I type to shell for: {text}, in one line.'''
23+
self.OPENAI_API_KEY = OPENAI_API_KEY
3024

3125
def query(self, prompt: str) -> str:
32-
prompt = self._construct_prompt(prompt)
26+
chatbot = Chatbot(api_key=self.OPENAI_API_KEY)
3327

34-
chatbot = Chatbot(api_key=self.openai_api_key)
28+
prompt = self._construct_prompt(prompt)
3529
response_text = chatbot.ask(prompt)
30+
executable_command = make_executable_command(response_text)
31+
return executable_command
3632

37-
return make_executable_command(response_text)
33+
def _construct_prompt(self, text: str) -> str:
34+
return f'''You are now a translater from human language to {os.uname()[0]} shell command.
35+
No explanation required, respond with only the raw shell command.
36+
What should I type to shell for: {text}, in one line.'''
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
import os
2-
from typing import Optional, cast
2+
from typing import Optional, Union, cast
33

44
from revChatGPT.V1 import Chatbot
55

66
from aishell.exceptions import UnauthorizedAccessError
7+
from aishell.models import RevChatGPTChatbotConfigModel
78
from aishell.utils import make_executable_command
89

910
from .query_client import QueryClient
1011

1112

1213
class ReverseEngineeredChatGPTClient(QueryClient):
13-
config: dict[str, str] = {}
14+
_config: RevChatGPTChatbotConfigModel
15+
16+
@property
17+
def revchatgpt_config(self) -> dict[str, Union[str, bool]]:
18+
return self._config.dict(exclude_none=True)
1419

1520
def __init__(
1621
self,
@@ -19,21 +24,17 @@ def __init__(
1924
):
2025
CHATGPT_ACCESS_TOKEN = os.environ.get('CHATGPT_ACCESS_TOKEN', access_token)
2126
CHATGPT_SESSION_TOKEN = os.environ.get('CHATGPT_SESSION_TOKEN', session_token)
22-
if CHATGPT_ACCESS_TOKEN is not None:
23-
self.config['access_token'] = CHATGPT_ACCESS_TOKEN
24-
elif CHATGPT_SESSION_TOKEN is not None:
25-
self.config['session_token'] = CHATGPT_SESSION_TOKEN
27+
if CHATGPT_ACCESS_TOKEN:
28+
self._config = RevChatGPTChatbotConfigModel(access_token=CHATGPT_ACCESS_TOKEN)
29+
elif CHATGPT_SESSION_TOKEN:
30+
self._config = RevChatGPTChatbotConfigModel(session_token=CHATGPT_SESSION_TOKEN)
2631
else:
2732
raise UnauthorizedAccessError('No access token or session token provided.')
2833

29-
def _construct_prompt(self, text: str) -> str:
30-
return f'''You are now a translater from human language to {os.uname()[0]} shell command.
31-
No explanation required, respond with only the raw shell command.
32-
What should I type to shell for: {text}, in one line.'''
33-
3434
def query(self, prompt: str) -> str:
3535
prompt = self._construct_prompt(prompt)
36-
chatbot = Chatbot(config=self.config)
36+
chatbot = Chatbot(config=self.revchatgpt_config) # pyright: ignore [reportGeneralTypeIssues]
37+
# ignore for wrong type hint of revchatgpt
3738

3839
response_text = ''
3940
for data in chatbot.ask(prompt):
@@ -42,3 +43,8 @@ def query(self, prompt: str) -> str:
4243
response_text = make_executable_command(cast(str, response_text))
4344

4445
return response_text
46+
47+
def _construct_prompt(self, text: str) -> str:
48+
return f'''You are now a translater from human language to {os.uname()[0]} shell command.
49+
No explanation required, respond with only the raw shell command.
50+
What should I type to shell for: {text}, in one line.'''

aishell/utils/str_enum.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
class StrEnum(str, Enum):
66

7-
def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]): # type: ignore
7+
def _generate_next_value_( # pyright: ignore [reportIncompatibleMethodOverride], for pyright's bug
8+
name: str,
9+
start: int,
10+
count: int,
11+
last_values: list[Any],
12+
):
813
return name.lower()
914

1015
def __repr__(self):

0 commit comments

Comments
 (0)