Skip to content
Snippets Groups Projects
Commit 783fe058 authored by Julian's avatar Julian
Browse files

removed flake8 due to errors, simplified client

parent f50deb01
No related branches found
No related tags found
No related merge requests found
...@@ -9,14 +9,12 @@ RUN conda env update -q -f /tmp/environment.yml && \ ...@@ -9,14 +9,12 @@ RUN conda env update -q -f /tmp/environment.yml && \
conda env export -n "root" && \ conda env export -n "root" && \
jupyter lab build jupyter lab build
RUN pip3 install --upgrade pip
RUN pip install jupyterlab_flake8
COPY dash_proxy /tmp/dash_proxy/ COPY dash_proxy /tmp/dash_proxy/
RUN pip install /tmp/dash_proxy/ RUN pip install /tmp/dash_proxy/
COPY llm_utils /llm_utils/ COPY llm_utils /llm_utils/
RUN pip install /llm_utils/ RUN pip install /llm_utils/
ENV CONFIG_PATH=/home/jovyan/config.txt
COPY app /dash/app/ COPY app /dash/app/
RUN chown -R jovyan /dash/app/ RUN chown -R jovyan /dash/app/
......
import os
from datetime import datetime from datetime import datetime
from dash import ( from dash import (
...@@ -10,7 +11,7 @@ from dash.dependencies import ( ...@@ -10,7 +11,7 @@ from dash.dependencies import (
State State
) )
from llm_utils.client import ChatGPT from llm_utils.client import ChatGPT, get_openai_client
def format_chat_messages(chat_history): def format_chat_messages(chat_history):
...@@ -24,8 +25,15 @@ def format_chat_messages(chat_history): ...@@ -24,8 +25,15 @@ def format_chat_messages(chat_history):
def register_callbacks(app: Dash): def register_callbacks(app: Dash):
model="gpt4"
chat_gpt = ChatGPT(model="gpt4") client = get_openai_client(
model=model,
config_path=os.environ.get("CONFIG_PATH")
)
chat_gpt = ChatGPT(
client=client,
model="gpt4"
)
@app.callback( @app.callback(
[Output('chat-container', 'children'), [Output('chat-container', 'children'),
......
import os import os
import logging
from openai import AzureOpenAI from openai import AzureOpenAI
from dotenv import load_dotenv from dotenv import load_dotenv
from enum import Enum from enum import Enum
try:
found_dotenv = load_dotenv(
"/home/jovyan/config.txt",
override=True
)
except ValueError:
logging.warn("Could not detect config.txt in /home/jovyan/. Searching in current folder ...")
found_dotenv = load_dotenv(
"config.txt",
override=True)
if not found_dotenv:
raise ValueError("Could not detect config.txt in /home/jovyan/.")
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION")
class OpenAIModels(Enum): class OpenAIModels(Enum):
GPT_3 = "gpt3" GPT_3 = "gpt3"
...@@ -33,13 +15,25 @@ class OpenAIModels(Enum): ...@@ -33,13 +15,25 @@ class OpenAIModels(Enum):
return [member.value for member in cls] return [member.value for member in cls]
def get_openai_client(model: str) -> AzureOpenAI: def get_openai_client(
model: str,
config_path: str
) -> AzureOpenAI:
if not model in OpenAIModels.get_all_values(): if not model in OpenAIModels.get_all_values():
raise ValueError(f"<model> needs to be one of {OpenAIModels.get_all_values()}.") raise ValueError(f"<model> needs to be one of {OpenAIModels.get_all_values()}.")
load_dotenv(
dotenv_path=config_path,
override=True
)
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION")
if any(p is None for p in (AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_KEY, OPENAI_API_VERSION)): if any(p is None for p in (AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_KEY, OPENAI_API_VERSION)):
raise ValueError( raise ValueError(
f"""None of the following parameters can be none: f"""None of the following parameters can be None:
AZURE_OPENAI_API_KEY: {AZURE_OPENAI_API_KEY}, AZURE_OPENAI_API_KEY: {AZURE_OPENAI_API_KEY},
AZURE_OPENAI_API_KEY: {AZURE_OPENAI_API_KEY}, AZURE_OPENAI_API_KEY: {AZURE_OPENAI_API_KEY},
OPENAI_API_VERSION: {OPENAI_API_VERSION} OPENAI_API_VERSION: {OPENAI_API_VERSION}
...@@ -56,9 +50,9 @@ def get_openai_client(model: str) -> AzureOpenAI: ...@@ -56,9 +50,9 @@ def get_openai_client(model: str) -> AzureOpenAI:
class ChatGPT: class ChatGPT:
def __init__(self, model="gpt4"): def __init__(self, client: AzureOpenAI, model: str):
self.model = model self.model = model
self.client = get_openai_client(model=model) self.client = client
self.messages = [] self.messages = []
def chat_with_gpt(self, user_input: str): def chat_with_gpt(self, user_input: str):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment