feat(ai): use open ai to design the harness for you
This commit is contained in:
parent
e8c482e94e
commit
ab9c0a75e2
@ -88,6 +88,13 @@ Read the [syntax description](syntax.md) to learn about WireViz' features and ho
|
|||||||
|
|
||||||
See the [tutorial page](../tutorial/readme.md) for sample code, as well as the [example gallery](../examples/readme.md) to see more of what WireViz can do.
|
See the [tutorial page](../tutorial/readme.md) for sample code, as well as the [example gallery](../examples/readme.md) to see more of what WireViz can do.
|
||||||
|
|
||||||
|
### Demo 03
|
||||||
|
wireviz -q "design a wire harness that connects two ethernet devices with d38999 connectors on both sides"
|
||||||
|
|
||||||
|
Use open AI to design the harness for you.
|
||||||
|
|
||||||
|
Define your openAI API key by setting an environment variable.
|
||||||
|
`OPENAI_KEY={}`
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
|||||||
89
src/wireviz/openai.py
Normal file
89
src/wireviz/openai.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
# Hit the chat gpt api to ask for harness design
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
import os
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
# set api key to environment variable OPENAI_KEY
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
PROMPT_CONTEXT = '''
|
||||||
|
You are a helpful assistant.
|
||||||
|
only output a yml file with pinout, notes, pin numbers. Please do not output any information other than the yml. It should follow this format:
|
||||||
|
|
||||||
|
connectors:
|
||||||
|
J1:
|
||||||
|
pinlabels: ["TX+", "TX-", "RX+", "RX-","BI_DD+", "BI_DD-", "BI_DC+", "BI_DC-"]
|
||||||
|
type: "D38999"
|
||||||
|
subtype: receptacle
|
||||||
|
J2:
|
||||||
|
pinlabels: ["TX+", "TX-", "RX+", "RX-","BI_DD+", "BI_DD-", "BI_DC+", "BI_DC-"]
|
||||||
|
type: "D38999"
|
||||||
|
subtype: receptacle
|
||||||
|
|
||||||
|
|
||||||
|
cables:
|
||||||
|
W1:
|
||||||
|
wirecount: 8
|
||||||
|
length: 2
|
||||||
|
gauge: 24 AWG
|
||||||
|
show_equiv: true
|
||||||
|
color_code: T568A
|
||||||
|
shield: true # cable shielding included
|
||||||
|
notes: Connect the cable shield to the backshell for EMI grounding
|
||||||
|
|
||||||
|
connections:
|
||||||
|
- # Connect twisted pairs for Ethernet
|
||||||
|
- J1: [1-8]
|
||||||
|
- W1: [1-8]
|
||||||
|
- J2: [1-8]
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Query the OpenAI API
|
||||||
|
def queryGPT(query, model="gpt-4"):
|
||||||
|
"""
|
||||||
|
Takes a query and returns the response from the OpenAI API.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
query (str): The input query or prompt.
|
||||||
|
model (str): The model to use for the query (default is "gpt-4").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from the OpenAI API.
|
||||||
|
"""
|
||||||
|
API_KEY= os.getenv("OPENAI_KEY")
|
||||||
|
client = OpenAI(api_key=API_KEY)
|
||||||
|
|
||||||
|
query = f"{PROMPT_CONTEXT}\n\n\n{query}"
|
||||||
|
try:
|
||||||
|
# Make an API call
|
||||||
|
response = client.chat.completions.create(model=model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": PROMPT_CONTEXT},
|
||||||
|
{"role": "user", "content": query}
|
||||||
|
])
|
||||||
|
# Extract the response content
|
||||||
|
resp = response.choices[0].message.content.strip()
|
||||||
|
return clean_gpt_response(resp)
|
||||||
|
except Exception as e:
|
||||||
|
return f"An error occurred: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def clean_gpt_response(response):
|
||||||
|
"""
|
||||||
|
Cleans the GPT response by removing any leading or trailing whitespace.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
response (str): The GPT response to clean.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The cleaned GPT response.
|
||||||
|
"""
|
||||||
|
# remove everything before the word connectors. and remove any trailing ``
|
||||||
|
response = response[response.find("connectors"):]
|
||||||
|
cleaned_output = response.replace("`", "")
|
||||||
|
|
||||||
|
return cleaned_output
|
||||||
@ -89,6 +89,7 @@ def parse(
|
|||||||
raise Exception("No output formats or return types specified")
|
raise Exception("No output formats or return types specified")
|
||||||
|
|
||||||
yaml_data, yaml_file = _get_yaml_data_and_path(inp)
|
yaml_data, yaml_file = _get_yaml_data_and_path(inp)
|
||||||
|
print(yaml_data)
|
||||||
if not isinstance(yaml_data, dict):
|
if not isinstance(yaml_data, dict):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Expected a dict as top-level YAML input, but got: {type(yaml_data)}"
|
f"Expected a dict as top-level YAML input, but got: {type(yaml_data)}"
|
||||||
@ -380,7 +381,7 @@ def parse(
|
|||||||
if "additional_bom_items" in yaml_data:
|
if "additional_bom_items" in yaml_data:
|
||||||
for line in yaml_data["additional_bom_items"]:
|
for line in yaml_data["additional_bom_items"]:
|
||||||
harness.add_bom_item(line)
|
harness.add_bom_item(line)
|
||||||
|
print(output_formats)
|
||||||
if output_formats:
|
if output_formats:
|
||||||
harness.output(filename=output_file, fmt=output_formats, view=False)
|
harness.output(filename=output_file, fmt=output_formats, view=False)
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ if __name__ == "__main__":
|
|||||||
import wireviz.wireviz as wv
|
import wireviz.wireviz as wv
|
||||||
from wireviz import APP_NAME, __version__
|
from wireviz import APP_NAME, __version__
|
||||||
from wireviz.wv_helper import file_read_text
|
from wireviz.wv_helper import file_read_text
|
||||||
|
from wireviz.openai import queryGPT
|
||||||
|
|
||||||
format_codes = {
|
format_codes = {
|
||||||
# "c": "csv",
|
# "c": "csv",
|
||||||
@ -71,7 +72,18 @@ epilog += ", ".join([f"{key} ({value.upper()})" for key, value in format_codes.i
|
|||||||
default=False,
|
default=False,
|
||||||
help=f"Output {APP_NAME} version and exit.",
|
help=f"Output {APP_NAME} version and exit.",
|
||||||
)
|
)
|
||||||
def wireviz(file, format, prepend, output_dir, output_name, version):
|
|
||||||
|
# add an option to input a chatgpt query
|
||||||
|
@click.option(
|
||||||
|
"-q",
|
||||||
|
"--query",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="Query to input to chatgpt.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def wireviz(file, format, prepend, output_dir, output_name, version, query):
|
||||||
"""
|
"""
|
||||||
Parses the provided FILE and generates the specified outputs.
|
Parses the provided FILE and generates the specified outputs.
|
||||||
"""
|
"""
|
||||||
@ -80,13 +92,24 @@ def wireviz(file, format, prepend, output_dir, output_name, version):
|
|||||||
if version:
|
if version:
|
||||||
return # print version number only and exit
|
return # print version number only and exit
|
||||||
|
|
||||||
# get list of files
|
# check query
|
||||||
try:
|
if query:
|
||||||
_ = iter(file)
|
gpt_response = queryGPT(query)
|
||||||
except TypeError:
|
yml_file= Path("/Users/gouthamsubramanian/harness.yml")
|
||||||
filepaths = [file]
|
with open(yml_file, "w") as file:
|
||||||
|
file.write(gpt_response)
|
||||||
|
print("Response written to", yml_file)
|
||||||
|
filepaths = [yml_file]
|
||||||
else:
|
else:
|
||||||
filepaths = list(file)
|
try:
|
||||||
|
_ = iter(file)
|
||||||
|
except TypeError:
|
||||||
|
filepaths = [file]
|
||||||
|
else:
|
||||||
|
filepaths = list(file)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# determine output formats
|
# determine output formats
|
||||||
output_formats = []
|
output_formats = []
|
||||||
@ -115,7 +138,6 @@ def wireviz(file, format, prepend, output_dir, output_name, version):
|
|||||||
else:
|
else:
|
||||||
prepend_input = ""
|
prepend_input = ""
|
||||||
|
|
||||||
# run WireVIz on each input file
|
|
||||||
for file in filepaths:
|
for file in filepaths:
|
||||||
file = Path(file)
|
file = Path(file)
|
||||||
if not file.exists():
|
if not file.exists():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user