Features that allow user to config LLM by config file. (#174)

This commit is contained in:
codescracker
2023-10-08 19:35:22 +08:00
committed by GitHub
parent d9bab899fe
commit 4bbd2c7446
17 changed files with 148 additions and 89 deletions

View File

@@ -0,0 +1,53 @@
# -*- coding:utf-8 -*-
import os
import configparser
from util.logging_utils import logger
def type_convert(input_str: str):
try:
return eval(input_str)
except:
return input_str
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
config_file = "run_config.ini"
config_path = os.path.join(PROJECT_DIR_PATH, config_file)
config = configparser.ConfigParser()
config.read(config_path)
llm_parser_section_name = "LLMParser"
LLMPARSER_HOST = config.get(llm_parser_section_name, 'LLMPARSER_HOST')
LLMPARSER_PORT = int(config.get(llm_parser_section_name, 'LLMPARSER_PORT'))
chroma_db_section_name = "ChromaDB"
CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME')
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM'))
TEXT2DSL_IS_SHORTCUT = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT'))
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
text2vec_section_name = "Text2Vec"
HF_TEXT2VEC_MODEL_NAME = config.get(text2vec_section_name, 'HF_TEXT2VEC_MODEL_NAME')
llm_provider_section_name = "LLMProvider"
LLM_PROVIDER_NAME = config.get(llm_provider_section_name, 'LLM_PROVIDER_NAME')
llm_model_section_name = "LLMModel"
llm_config_dict = {}
for option in config.options(llm_model_section_name):
llm_config_dict[option] = type_convert(config.get(llm_model_section_name, option))
if __name__ == "__main__":
logger.info("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH)
logger.info("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME)
logger.info("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH)
logger.info("LLMPARSER_HOST: ", LLMPARSER_HOST)
logger.info("LLMPARSER_PORT: ", LLMPARSER_PORT)
logger.info("llm_config_dict: ", llm_config_dict)
logger.info("is_shortcut: ", TEXT2DSL_IS_SHORTCUT)

View File

@@ -0,0 +1,24 @@
[LLMParser]
LLMPARSER_HOST = 127.0.0.1
LLMPARSER_PORT = 9092
[ChromaDB]
CHROMA_DB_PERSIST_DIR = chm_db
PRESET_QUERY_COLLECTION_NAME = preset_query_collection
SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
TEXT2DSL_COLLECTION_NAME = text2dsl_collection
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
TEXT2DSL_IS_SHORTCUT = False
[Text2Vec]
HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese
[LLMProvider]
LLM_PROVIDER_NAME = openai
[LLMModel]
MODEL_NAME = gpt-3.5-turbo-16k
OPENAI_API_KEY = YOUR_API_KEY
TEMPERATURE = 0.0

View File

@@ -1,31 +0,0 @@
# -*- coding:utf-8 -*-
import os
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
LLMPARSER_HOST = "127.0.0.1"
LLMPARSER_PORT = 9092
MODEL_NAME = "gpt-3.5-turbo-16k"
OPENAI_API_KEY = "YOUR_API_KEY"
OPENAI_API_BASE = ""
TEMPERATURE = 0.0
CHROMA_DB_PERSIST_DIR = "chm_db"
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
SOLVED_QUERY_COLLECTION_NAME = "solved_query_collection"
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
TEXT2DSL_IS_SHORTCUT = False
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
HF_TEXT2VEC_MODEL_NAME = "GanymedeNil/text2vec-large-chinese"
if __name__ == "__main__":
logger.info("PROJECT_DIR_PATH: {}", PROJECT_DIR_PATH)
logger.info("EMB_MODEL_PATH: {}", HF_TEXT2VEC_MODEL_NAME)
logger.info("CHROMA_DB_PERSIST_PATH: {}", CHROMA_DB_PERSIST_PATH)
logger.info("LLMPARSER_HOST: {}", LLMPARSER_HOST)
logger.info("LLMPARSER_PORT: {}", LLMPARSER_PORT)

View File

@@ -5,7 +5,7 @@ import re
import sys
from typing import Any, List, Mapping, Union
from loguru import logger
from util.logging_utils import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

View File

@@ -7,6 +7,7 @@ from typing import List
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from chromadb.api import Collection
from preset_query_db import (
@@ -18,7 +19,7 @@ from preset_query_db import (
from util.text2vec import Text2VecEmbeddingFunction
from run_config import PRESET_QUERY_COLLECTION_NAME
from config.config_parse import PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import client

View File

@@ -5,18 +5,18 @@ import sys
import uuid
from typing import Any, List, Mapping, Optional, Union
from loguru import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
import chromadb
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from util.text2vec import Text2VecEmbeddingFunction
from run_config import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import (client,
get_chroma_collection_size, query_chroma_collection,
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,

View File

@@ -3,18 +3,18 @@ import os
import sys
from typing import List, Mapping
from loguru import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from langchain.vectorstores import Chroma
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
from few_shot_example.sql_exampler import examplars as sql_examplars
from util.text2vec import hg_embedding
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
def reload_sql_example_collection(

View File

@@ -6,12 +6,14 @@ from typing import List, Mapping
import requests
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT
from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT,
LLMPARSER_HOST, LLMPARSER_PORT)
from few_shot_example.sql_exampler import examplars as sql_examplars
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
from util.logging_utils import logger
def text2dsl_setting_update(

View File

@@ -1,6 +1,15 @@
# -*- coding:utf-8 -*-
import re
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
def schema_link_parse(schema_link_output):
try:

View File

@@ -6,6 +6,8 @@ from typing import List, Mapping
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from langchain.prompts import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector

View File

@@ -2,11 +2,11 @@ import os
import sys
from typing import List, Union, Mapping
from loguru import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from sql.prompt_maker import (
schema_linking_exampler,
sql_exampler,
@@ -24,7 +24,7 @@ from sql.output_parser import (
)
from util.llm_instance import llm
from run_config import TEXT2DSL_IS_SHORTCUT
from config.config_parse import TEXT2DSL_IS_SHORTCUT
class Text2DSLAgent(object):
@@ -65,10 +65,10 @@ class Text2DSLAgent(object):
linking: Union[List[Mapping[str, str]], None] = None,
):
logger.info("query_text: {}", query_text)
logger.info("schema: {}", schema)
logger.info("current_date: {}", current_date)
logger.info("prior_schema_links: {}", linking)
logger.info("query_text: {}".format(query_text))
logger.info("schema: {}".format(schema))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(linking))
if linking is not None:
prior_schema_links = {
@@ -87,7 +87,7 @@ class Text2DSLAgent(object):
prior_schema_links,
self.sql_example_selector,
)
logger.info("schema_linking_prompt-> {}", schema_linking_prompt)
logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt))
schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output)
@@ -98,7 +98,7 @@ class Text2DSLAgent(object):
current_date,
self.sql_example_selector,
)
logger.info("sql_prompt->", sql_prompt)
logger.info("sql_prompt-> {}".format(sql_prompt))
sql_output = self.llm(sql_prompt)
resp = dict()
@@ -113,7 +113,7 @@ class Text2DSLAgent(object):
resp["sqlOutput"] = sql_output
logger.info("resp: ", resp)
logger.info("resp: {}".format(resp))
return resp
@@ -125,10 +125,10 @@ class Text2DSLAgent(object):
linking: Union[List[Mapping[str, str]], None] = None,
):
logger.info("query_text: ", query_text)
logger.info("schema: ", schema)
logger.info("current_date: ", current_date)
logger.info("prior_schema_links: ", linking)
logger.info("query_text: {}".format(query_text))
logger.info("schema: {}".format(schema))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(linking))
if linking is not None:
prior_schema_links = {
@@ -148,7 +148,7 @@ class Text2DSLAgent(object):
prior_schema_links,
self.sql_example_selector,
)
logger.info("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt))
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
schema_linking_str = self.combo_schema_link_parse(
@@ -167,7 +167,7 @@ class Text2DSLAgent(object):
resp["schemaLinkStr"] = schema_linking_str
resp["sqlOutput"] = sql_str
logger.info("resp: ", resp)
logger.info("resp: {}".format(resp))
return resp

View File

@@ -4,10 +4,6 @@ import sys
import uvicorn
from util.logging_utils import init_logger
init_logger()
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
@@ -15,7 +11,7 @@ from typing import Any, List, Mapping
from fastapi import FastAPI, HTTPException
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
from config.config_parse import LLMPARSER_HOST, LLMPARSER_PORT
from services_router import (query2sql_service, preset_query_service,
solved_query_service, plugin_call_service)

View File

@@ -4,9 +4,15 @@ from typing import Any, List, Mapping, Optional, Union
import chromadb
from chromadb.api import Collection
from chromadb.config import Settings
from loguru import logger
from run_config import CHROMA_DB_PERSIST_PATH
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from config.config_parse import CHROMA_DB_PERSIST_PATH
client = chromadb.Client(
Settings(
@@ -75,7 +81,7 @@ def query_chroma_collection(collection:Collection, query_texts:List[str],
else:
outer_filter = None
logger.info('outer_filter: ', outer_filter)
print('outer_filter: ', outer_filter)
res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter)
return res

View File

@@ -1,12 +1,21 @@
# -*- coding:utf-8 -*-
from langchain.llms import OpenAI
from langchain import llms
from run_config import *
from util.stringutils import *
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
llm = OpenAI(
model_name=MODEL_NAME,
openai_api_key=OPENAI_API_KEY,
openai_api_base=default_if_blank(OPENAI_API_BASE),
temperature=TEMPERATURE,
)
from config.config_parse import LLM_PROVIDER_NAME, llm_config_dict
def get_llm_provider(llm_provider_name: str, llm_config_dict: dict):
if llm_provider_name in llms.type_to_cls_dict:
llm_provider = llms.type_to_cls_dict[llm_provider_name]
llm = llm_provider(**llm_config_dict)
return llm
else:
raise Exception("llm_provider_name is not supported: {}".format(llm_provider_name))
llm = get_llm_provider(LLM_PROVIDER_NAME, llm_config_dict)

View File

@@ -1,6 +1,4 @@
from loguru import logger
import sys
def init_logger():
logger.remove()
logger.add("llmparser.info.log")
# logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")

View File

@@ -1,10 +0,0 @@
def is_blank(s: str) -> bool:
return not (s and s.strip())
def is_not_blank(s: str) -> bool:
return not is_blank(s)
def default_if_blank(s: str, default: str = None) -> str:
return s if is_not_blank(s) else default

View File

@@ -4,7 +4,7 @@ from typing import List
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from langchain.embeddings import HuggingFaceEmbeddings
from run_config import HF_TEXT2VEC_MODEL_NAME
from config.config_parse import HF_TEXT2VEC_MODEL_NAME
hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)