A Simple OpenAI API Wrapper
I designed a simple Python wrapper around OpenAI’s API, intended for personal use. The wrapper focuses on easily obtaining the desired output from OpenAI’s models while also allowing the flexibility to define the output format. Let’s break down what each part of the code does:
Overview of Class and Methods
__init__
: Initializes the LLM
(Large Language Model) class with an API key, the name of the model (defaulting to "gpt-3.5-turbo"), and a temperature parameter that controls the randomness of the generated text.
model_structure_repr
: Generates a string representation of a Pydantic BaseModel
structure. This is useful for understanding what keys a model requires.
is_valid_json_for_model
: Checks if a given JSON string is not only valid JSON but also conforms to a given Pydantic BaseModel.
generate_text
: Interacts with OpenAI's API to produce text based on a given prompt. Optionally, you can specify an output format (as a Pydantic BaseModel
), the number of completions, and a maximum token limit.
Detailed Breakdown
Initialization (__init__)
When you create an instance of LLM
, you pass in your OpenAI API key, the model type, and the temperature. These are stored as instance variables and used later in API calls.
def __init__(self, api_key, model="gpt-3.5-turbo", temperature=0.5):
    """Store connection settings and register the key with the openai module.

    Args:
        api_key: OpenAI API key; also assigned to the module-global
            ``openai.api_key`` so subsequent API calls authenticate.
        model: Name of the chat model to use for completions.
        temperature: Sampling temperature controlling output randomness.
    """
    # Authenticate the openai module up front, then remember the settings
    # for later API calls.
    openai.api_key = api_key
    self.api_key = api_key
    self.model = model
    self.temperature = temperature
Model Structure Representation (model_structure_repr)
This method takes a Pydantic BaseModel
as an argument and recursively generates a string representation of the model's structure. It's useful for providing feedback on what kind of JSON to expect.
def model_structure_repr(self, model: Type[BaseModel]) -> str:
    """Recursively describe a Pydantic model's fields as a brace-delimited string.

    Each field is rendered as ``name: type (description)``; nested models and
    lists are expanded recursively, e.g. ``{title: str (No description),
    tags: [str](No description)}``. Useful for telling an LLM which JSON keys
    to produce.

    Args:
        model: A Pydantic v2 ``BaseModel`` subclass.

    Returns:
        A single-line string describing the model's keys, types, and
        field descriptions.
    """
    field_reprs = []
    # Pydantic v2: ``model_fields`` maps names directly to FieldInfo objects
    # (the old ``field.field_info`` indirection was a v1-only API and raises
    # AttributeError under v2, which this file targets via ConfigDict).
    for name, field in model.model_fields.items():
        description = field.description or "No description"
        field_type = field.annotation
        if getattr(field_type, '__origin__', None) is list:
            inner_type = field_type.__args__[0]
            # Expand lists of models recursively; otherwise show the element type.
            if isinstance(inner_type, type) and issubclass(inner_type, BaseModel):
                inner_repr = self.model_structure_repr(inner_type)
                field_reprs.append(f"{name}: [{inner_repr}]({description})")
            else:
                inner_name = getattr(inner_type, '__name__', repr(inner_type))
                field_reprs.append(f"{name}: [{inner_name}]({description})")
        # ``isinstance(..., type)`` guard: issubclass raises TypeError on
        # non-class annotations such as Optional[int] or Union[...].
        elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
            inner_repr = self.model_structure_repr(field_type)
            field_reprs.append(f"{name}: {inner_repr} ({description})")
        else:
            # Basic types (str, int, ...) and non-class annotations.
            type_name = getattr(field_type, '__name__', repr(field_type))
            field_reprs.append(f"{name}: {type_name} ({description})")
    return f"{{{', '.join(field_reprs)}}}"
Validating JSON (is_valid_json_for_model)
This method tries to parse a JSON string and checks whether it fits into a specified Pydantic BaseModel
. It returns a boolean value based on the validation.
def is_valid_json_for_model(self, text: str, model: Type[BaseModel]) -> bool:
    """Check if *text* is valid JSON that strictly conforms to *model*.

    Args:
        text: Candidate JSON string.
        model: Pydantic v2 ``BaseModel`` subclass to validate against.

    Returns:
        True if the text parses as JSON and validates against the model
        in strict mode; False otherwise.
    """
    try:
        parsed_data = json.loads(text)
    except json.JSONDecodeError:
        return False
    try:
        # Validate with strict=True per call instead of assigning
        # ``model.model_config``, which permanently replaced the model
        # class's entire config for every other caller.
        model.model_validate(parsed_data, strict=True)
        return True
    # TypeError covers JSON that parses to a non-dict (e.g. a bare list).
    except (ValidationError, TypeError):
        return False
Text Generation (generate_text)
Here, the magic happens. You send a system message and a user prompt to the OpenAI API. The API returns a response based on the model and temperature you’ve set.
If you specify an output format (a Pydantic BaseModel
), the response from the API must conform to that model. Otherwise, it's ignored. The method also handles rate-limiting by waiting and retrying the API call.
def generate_text(self, prompt, output_format: Optional[BaseModel] = None, n_completions=1, max_tokens=None, max_attempts=10):
    """Request chat completions, optionally validated against a Pydantic model.

    Args:
        prompt: User message sent to the model.
        output_format: Optional Pydantic model; when given, the system message
            describes its keys and only responses that validate are kept
            (returned as parsed dicts rather than raw strings).
        n_completions: Number of (valid) completions to collect.
        max_tokens: Optional per-completion token cap.
        max_attempts: Upper bound on API calls, so a model that never emits
            valid JSON cannot loop forever.

    Returns:
        A list of at most ``n_completions`` responses: parsed JSON objects
        when ``output_format`` is set, raw content strings otherwise.
    """
    retry_delay = 0.1  # initial delay is 100 milliseconds; doubles per rate-limit hit
    valid_responses = []
    # The system message is loop-invariant, so build it once up front.
    system_message = "You are a helpful assistant."
    if output_format:
        system_message += f" Respond in a json format that contains the following keys: {self.model_structure_repr(output_format)}"
    attempts = 0
    while len(valid_responses) < n_completions and attempts < max_attempts:
        attempts += 1
        try:
            params = {
                "model": self.model,
                "messages": [
                    {
                        "role": "system",
                        "content": system_message
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "temperature": self.temperature,
                # Only request the completions still missing instead of a
                # full batch of n_completions on every retry.
                "n": n_completions - len(valid_responses)
            }
            if max_tokens is not None:
                params["max_tokens"] = max_tokens
            response = openai.ChatCompletion.create(**params)
            responses = [choice["message"]["content"] for choice in response["choices"]]
            if output_format:
                # Keep only responses that validate; return them parsed.
                valid_responses.extend(
                    json.loads(res) for res in responses
                    if self.is_valid_json_for_model(res, output_format)
                )
            else:
                valid_responses.extend(responses)
        except openai.error.RateLimitError:
            print(f"Hit rate limit. Retrying in {retry_delay} seconds.")
            time.sleep(retry_delay)
            retry_delay *= 2  # exponential backoff
        except Exception as err:
            print(f"Error: {err}")
            break
    return valid_responses[:n_completions]
Error Handling and Rate Limiting
The code is also designed to handle rate-limiting errors from the API by exponentially increasing the wait time before trying the request again.
That’s a quick overview! This wrapper simplifies several complex tasks into easy-to-use methods, enabling more efficient and convenient interaction with OpenAI’s API.
Full code:
import openai
import time
import json
from pydantic import BaseModel, ValidationError, ConfigDict
from typing import Type, Optional
class LLM:
    """Minimal wrapper around OpenAI's chat completion API.

    Optionally validates responses against a Pydantic v2 model so callers
    receive already-parsed JSON objects instead of raw text, and retries
    with exponential backoff on rate limits.
    """

    def __init__(self, api_key, model="gpt-3.5-turbo", temperature=0.5):
        """Store settings and register the key with the openai module.

        Args:
            api_key: OpenAI API key; also assigned to the module-global
                ``openai.api_key`` so subsequent API calls authenticate.
            model: Name of the chat model to use for completions.
            temperature: Sampling temperature controlling output randomness.
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        openai.api_key = self.api_key

    def model_structure_repr(self, model: Type[BaseModel]) -> str:
        """Recursively describe a Pydantic model's fields as a brace-delimited string.

        Each field is rendered as ``name: type (description)``; nested models
        and lists are expanded recursively. Useful for telling an LLM which
        JSON keys to produce.

        Args:
            model: A Pydantic v2 ``BaseModel`` subclass.

        Returns:
            A single-line string describing the model's keys, types, and
            field descriptions.
        """
        field_reprs = []
        # Pydantic v2: ``model_fields`` maps names directly to FieldInfo
        # objects (the old ``field.field_info`` indirection was a v1-only API
        # and raises AttributeError under v2, which this file targets).
        for name, field in model.model_fields.items():
            description = field.description or "No description"
            field_type = field.annotation
            if getattr(field_type, '__origin__', None) is list:
                inner_type = field_type.__args__[0]
                # Expand lists of models recursively; otherwise show the element type.
                if isinstance(inner_type, type) and issubclass(inner_type, BaseModel):
                    inner_repr = self.model_structure_repr(inner_type)
                    field_reprs.append(f"{name}: [{inner_repr}]({description})")
                else:
                    inner_name = getattr(inner_type, '__name__', repr(inner_type))
                    field_reprs.append(f"{name}: [{inner_name}]({description})")
            # ``isinstance(..., type)`` guard: issubclass raises TypeError on
            # non-class annotations such as Optional[int] or Union[...].
            elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
                inner_repr = self.model_structure_repr(field_type)
                field_reprs.append(f"{name}: {inner_repr} ({description})")
            else:
                # Basic types (str, int, ...) and non-class annotations.
                type_name = getattr(field_type, '__name__', repr(field_type))
                field_reprs.append(f"{name}: {type_name} ({description})")
        return f"{{{', '.join(field_reprs)}}}"

    def is_valid_json_for_model(self, text: str, model: Type[BaseModel]) -> bool:
        """Check if *text* is valid JSON that strictly conforms to *model*.

        Args:
            text: Candidate JSON string.
            model: Pydantic v2 ``BaseModel`` subclass to validate against.

        Returns:
            True if the text parses as JSON and validates against the model
            in strict mode; False otherwise.
        """
        try:
            parsed_data = json.loads(text)
        except json.JSONDecodeError:
            return False
        try:
            # Validate with strict=True per call instead of assigning
            # ``model.model_config``, which permanently replaced the model
            # class's entire config for every other caller.
            model.model_validate(parsed_data, strict=True)
            return True
        # TypeError covers JSON that parses to a non-dict (e.g. a bare list).
        except (ValidationError, TypeError):
            return False

    def generate_text(self, prompt, output_format: Optional[BaseModel] = None, n_completions=1, max_tokens=None, max_attempts=10):
        """Request chat completions, optionally validated against a Pydantic model.

        Args:
            prompt: User message sent to the model.
            output_format: Optional Pydantic model; when given, the system
                message describes its keys and only responses that validate
                are kept (returned as parsed dicts rather than raw strings).
            n_completions: Number of (valid) completions to collect.
            max_tokens: Optional per-completion token cap.
            max_attempts: Upper bound on API calls, so a model that never
                emits valid JSON cannot loop forever.

        Returns:
            A list of at most ``n_completions`` responses: parsed JSON
            objects when ``output_format`` is set, raw strings otherwise.
        """
        retry_delay = 0.1  # initial delay is 100 milliseconds; doubles per rate-limit hit
        valid_responses = []
        # The system message is loop-invariant, so build it once up front.
        system_message = "You are a helpful assistant."
        if output_format:
            system_message += f" Respond in a json format that contains the following keys: {self.model_structure_repr(output_format)}"
        attempts = 0
        while len(valid_responses) < n_completions and attempts < max_attempts:
            attempts += 1
            try:
                params = {
                    "model": self.model,
                    "messages": [
                        {
                            "role": "system",
                            "content": system_message
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    "temperature": self.temperature,
                    # Only request the completions still missing instead of a
                    # full batch of n_completions on every retry.
                    "n": n_completions - len(valid_responses)
                }
                if max_tokens is not None:
                    params["max_tokens"] = max_tokens
                response = openai.ChatCompletion.create(**params)
                responses = [choice["message"]["content"] for choice in response["choices"]]
                if output_format:
                    # Keep only responses that validate; return them parsed.
                    valid_responses.extend(
                        json.loads(res) for res in responses
                        if self.is_valid_json_for_model(res, output_format)
                    )
                else:
                    valid_responses.extend(responses)
            except openai.error.RateLimitError:
                print(f"Hit rate limit. Retrying in {retry_delay} seconds.")
                time.sleep(retry_delay)
                retry_delay *= 2  # exponential backoff
            except Exception as err:
                print(f"Error: {err}")
                break
        return valid_responses[:n_completions]