feat(ai): use open ai to design the harness for you

This commit is contained in:
gsubmarine 2025-01-12 15:39:39 -06:00
parent e8c482e94e
commit ab9c0a75e2
4 changed files with 128 additions and 9 deletions

View File

@ -88,6 +88,13 @@ Read the [syntax description](syntax.md) to learn about WireViz' features and ho
See the [tutorial page](../tutorial/readme.md) for sample code, as well as the [example gallery](../examples/readme.md) to see more of what WireViz can do.
### Demo 03
wireviz -q "design a wire harness that connects two ethernet devices with d38999 connectors on both sides"
Use open AI to design the harness for you.
Define your openAI API key by setting an environment variable.
`OPENAI_KEY={}`
## Usage

89
src/wireviz/openai.py Normal file
View File

@ -0,0 +1,89 @@
# Hit the chat gpt api to ask for harness design
from openai import OpenAI
import os
import subprocess
# set api key to environment variable OPENAI_KEY
import json
import sys
PROMPT_CONTEXT = '''
You are a helpful assistant.
only output a yml file with pinout, notes, pin numbers. Please do not output any information other than the yml. It should follow this format:
connectors:
J1:
pinlabels: ["TX+", "TX-", "RX+", "RX-","BI_DD+", "BI_DD-", "BI_DC+", "BI_DC-"]
type: "D38999"
subtype: receptacle
J2:
pinlabels: ["TX+", "TX-", "RX+", "RX-","BI_DD+", "BI_DD-", "BI_DC+", "BI_DC-"]
type: "D38999"
subtype: receptacle
cables:
W1:
wirecount: 8
length: 2
gauge: 24 AWG
show_equiv: true
color_code: T568A
shield: true # cable shielding included
notes: Connect the cable shield to the backshell for EMI grounding
connections:
- # Connect twisted pairs for Ethernet
- J1: [1-8]
- W1: [1-8]
- J2: [1-8]
'''
# Query the OpenAI API
def queryGPT(query, model="gpt-4"):
"""
Takes a query and returns the response from the OpenAI API.
Parameters:
query (str): The input query or prompt.
model (str): The model to use for the query (default is "gpt-4").
Returns:
str: The response from the OpenAI API.
"""
API_KEY= os.getenv("OPENAI_KEY")
client = OpenAI(api_key=API_KEY)
query = f"{PROMPT_CONTEXT}\n\n\n{query}"
try:
# Make an API call
response = client.chat.completions.create(model=model,
messages=[
{"role": "system", "content": PROMPT_CONTEXT},
{"role": "user", "content": query}
])
# Extract the response content
resp = response.choices[0].message.content.strip()
return clean_gpt_response(resp)
except Exception as e:
return f"An error occurred: {e}"
def clean_gpt_response(response):
"""
Cleans the GPT response by removing any leading or trailing whitespace.
Parameters:
response (str): The GPT response to clean.
Returns:
str: The cleaned GPT response.
"""
# remove everything before the word connectors. and remove any trailing ``
response = response[response.find("connectors"):]
cleaned_output = response.replace("`", "")
return cleaned_output

View File

@ -89,6 +89,7 @@ def parse(
raise Exception("No output formats or return types specified")
yaml_data, yaml_file = _get_yaml_data_and_path(inp)
print(yaml_data)
if not isinstance(yaml_data, dict):
raise TypeError(
f"Expected a dict as top-level YAML input, but got: {type(yaml_data)}"
@ -380,7 +381,7 @@ def parse(
if "additional_bom_items" in yaml_data:
for line in yaml_data["additional_bom_items"]:
harness.add_bom_item(line)
print(output_formats)
if output_formats:
harness.output(filename=output_file, fmt=output_formats, view=False)

View File

@ -12,6 +12,7 @@ if __name__ == "__main__":
import wireviz.wireviz as wv
from wireviz import APP_NAME, __version__
from wireviz.wv_helper import file_read_text
from wireviz.openai import queryGPT
format_codes = {
# "c": "csv",
@ -71,7 +72,18 @@ epilog += ", ".join([f"{key} ({value.upper()})" for key, value in format_codes.i
default=False,
help=f"Output {APP_NAME} version and exit.",
)
def wireviz(file, format, prepend, output_dir, output_name, version):
# add an option to input a chatgpt query
@click.option(
"-q",
"--query",
default=None,
type=str,
help="Query to input to chatgpt.",
)
def wireviz(file, format, prepend, output_dir, output_name, version, query):
"""
Parses the provided FILE and generates the specified outputs.
"""
@ -80,13 +92,24 @@ def wireviz(file, format, prepend, output_dir, output_name, version):
if version:
return # print version number only and exit
# get list of files
try:
_ = iter(file)
except TypeError:
filepaths = [file]
# check query
if query:
gpt_response = queryGPT(query)
yml_file= Path("/Users/gouthamsubramanian/harness.yml")
with open(yml_file, "w") as file:
file.write(gpt_response)
print("Response written to", yml_file)
filepaths = [yml_file]
else:
filepaths = list(file)
try:
_ = iter(file)
except TypeError:
filepaths = [file]
else:
filepaths = list(file)
# determine output formats
output_formats = []
@ -115,7 +138,6 @@ def wireviz(file, format, prepend, output_dir, output_name, version):
else:
prepend_input = ""
# run WireVIz on each input file
for file in filepaths:
file = Path(file)
if not file.exists():