diff --git a/chatgpt_pyapi/__init__.py b/chatgpt_pyapi/__init__.py index 9f80222..566d26f 100644 --- a/chatgpt_pyapi/__init__.py +++ b/chatgpt_pyapi/__init__.py @@ -21,15 +21,20 @@ class Message: '''Message type. Supports roles.''' def __init__(self, text: str, role: str = Roles.USER): + assert type(text) == str + assert type(role) == str self.text = text self.role = role @classmethod - def from_api(cls, message_dict:str): + def from_api(cls, api_msg: dict): '''Create a Message object from API format''' + assert type(api_msg) == dict + msg = api_msg["choices"][0]["message"] + msg["content"] = msg["content"].strip("\n") return cls( - message_dict["content"], - message_dict["role"]) + msg["content"], + msg["role"]) def to_api(self): '''Convert to API format''' @@ -42,6 +47,8 @@ class ChatGPT: API_ENDPOINT = "https://api.openai.com/v1/chat/completions" def __init__(self, api_key: str, model: str = Models.GPT_35_TURBO): + assert type(api_key) == str + assert type(model) == str # Create string used in header self.auth = f"Bearer {api_key}" # This list will contain all prior messages @@ -51,12 +58,12 @@ class ChatGPT: def chat(self, message: Message) -> Message: '''Add a message to the message history & send it to ChatGPT. Returns the answer as a Message instance.''' self.add_to_chat(message) - # create api_input from message_history & encode it + # Create api_input from message_history & encode it api_input = [m.to_api() for m in self._message_history] api_input_encoded = dumps( {"model": self.model, "messages": api_input}, separators=(",", ":")).encode() - # create a Request object with the right url, data, headers and http method + # Create a Request object with the right url, data, headers and http method request = http_request.Request( self.API_ENDPOINT, data=api_input_encoded, @@ -65,21 +72,18 @@ class ChatGPT: "Content-Type": "application/json" }, method="POST") - # send the request with r as the response + # Send the request with r as the response with http_request.urlopen(request) as r: - # read response and parse json + # Read response and parse json api_output = loads(r.read()) - api_output_answer = api_output["choices"][0]["message"] - # remove leading and trailing newlines - api_output_answer["content"] = api_output_answer["content"].strip("\n") - # convert to Message object - response_message = Message.from_api(api_output_answer) + # Convert to Message object + response_message = Message.from_api(api_output) self._message_history.append(response_message) return response_message def add_to_chat(self, message: Message) -> Message: '''Add a message to the message history without sending it to ChatGPT''' - # check if the message parameter is the correct type + # Check if the message parameter is the correct type assert type(message) == Message, "message must be an instance of Message" self._message_history.append(message)