Improved code by renaming variables, adding more asserts and moving code segments

This commit is contained in:
Julian Müller (ChaoticByte) 2023-03-11 14:10:13 +01:00
parent e8b3b97da9
commit 4fd0969457

View file

@ -21,15 +21,20 @@ class Message:
'''Message type. Supports roles.''' '''Message type. Supports roles.'''
def __init__(self, text: str, role: str = Roles.USER): def __init__(self, text: str, role: str = Roles.USER):
assert type(text) == str
assert type(role) == str
self.text = text self.text = text
self.role = role self.role = role
@classmethod @classmethod
def from_api(cls, message_dict:str): def from_api(cls, api_msg: dict):
'''Create a Message object from API format''' '''Create a Message object from API format'''
assert type(api_msg) == dict
msg = api_msg["choices"][0]["message"]
msg["content"] = msg["content"].strip("\n")
return cls( return cls(
message_dict["content"], msg["content"],
message_dict["role"]) msg["role"])
def to_api(self): def to_api(self):
'''Convert to API format''' '''Convert to API format'''
@ -42,6 +47,8 @@ class ChatGPT:
API_ENDPOINT = "https://api.openai.com/v1/chat/completions" API_ENDPOINT = "https://api.openai.com/v1/chat/completions"
def __init__(self, api_key: str, model: str = Models.GPT_35_TURBO): def __init__(self, api_key: str, model: str = Models.GPT_35_TURBO):
assert type(api_key) == str
assert type(model) == str
# Create string used in header # Create string used in header
self.auth = f"Bearer {api_key}" self.auth = f"Bearer {api_key}"
# This list will contain all prior messages # This list will contain all prior messages
@ -51,12 +58,12 @@ class ChatGPT:
def chat(self, message: Message) -> Message: def chat(self, message: Message) -> Message:
'''Add a message to the message history & send it to ChatGPT. Returns the answer as a Message instance.''' '''Add a message to the message history & send it to ChatGPT. Returns the answer as a Message instance.'''
self.add_to_chat(message) self.add_to_chat(message)
# create api_input from message_history & encode it # Create api_input from message_history & encode it
api_input = [m.to_api() for m in self._message_history] api_input = [m.to_api() for m in self._message_history]
api_input_encoded = dumps( api_input_encoded = dumps(
{"model": self.model, "messages": api_input}, {"model": self.model, "messages": api_input},
separators=(",", ":")).encode() separators=(",", ":")).encode()
# create a Request object with the right url, data, headers and http method # Create a Request object with the right url, data, headers and http method
request = http_request.Request( request = http_request.Request(
self.API_ENDPOINT, self.API_ENDPOINT,
data=api_input_encoded, data=api_input_encoded,
@ -65,21 +72,18 @@ class ChatGPT:
"Content-Type": "application/json" "Content-Type": "application/json"
}, },
method="POST") method="POST")
# send the request with r as the response # Send the request with r as the response
with http_request.urlopen(request) as r: with http_request.urlopen(request) as r:
# read response and parse json # Read response and parse json
api_output = loads(r.read()) api_output = loads(r.read())
api_output_answer = api_output["choices"][0]["message"] # Convert to Message object
# remove leading and trailing newlines response_message = Message.from_api(api_output)
api_output_answer["content"] = api_output_answer["content"].strip("\n")
# convert to Message object
response_message = Message.from_api(api_output_answer)
self._message_history.append(response_message) self._message_history.append(response_message)
return response_message return response_message
def add_to_chat(self, message: Message) -> Message: def add_to_chat(self, message: Message) -> Message:
'''Add a message to the message history without sending it to ChatGPT''' '''Add a message to the message history without sending it to ChatGPT'''
# check if the message parameter is the correct type # Check if the message parameter is the correct type
assert type(message) == Message, "message must be an instance of Message" assert type(message) == Message, "message must be an instance of Message"
self._message_history.append(message) self._message_history.append(message)