Feature: allow users to configure the LLM through a config file. (#174)

This commit is contained in:
codescracker
2023-10-08 19:35:22 +08:00
committed by GitHub
parent d9bab899fe
commit 4bbd2c7446
17 changed files with 148 additions and 89 deletions

View File

@@ -0,0 +1,53 @@
# -*- coding:utf-8 -*-
import os
import configparser
from util.logging_utils import logger
def type_convert(input_str: str):
    """Convert a raw config string into a Python literal when possible.

    Values read from the INI file are always strings. This turns text such
    as ``"0.0"``, ``"15"``, ``"True"`` or ``"[1, 2]"`` into the matching
    Python object and leaves everything else (model names, API keys, ...)
    untouched.

    Only literals are parsed: ``ast.literal_eval`` never executes code, so a
    config value cannot run arbitrary expressions or accidentally resolve to
    a name that happens to be in scope.

    Args:
        input_str: Raw option value as read from the config file.

    Returns:
        The parsed literal, or ``input_str`` unchanged when it is not a
        valid Python literal.
    """
    # Local import so the function does not rely on a file-level import.
    import ast

    try:
        return ast.literal_eval(input_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a literal (e.g. "gpt-3.5-turbo-16k"): keep the raw value.
        return input_str
# Directory that contains this module; run_config.ini is looked up next to it.
# NOTE(review): callers import this module as ``config.config_parse``, so this
# may be the ``config/`` directory rather than the project root -- confirm that
# CHROMA_DB_PERSIST_PATH below still resolves to the intended location.
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))

config_file = "run_config.ini"
config_path = os.path.join(PROJECT_DIR_PATH, config_file)

config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed. Fail loudly here instead of surfacing a
# misleading NoSectionError on the first lookup below.
if not config.read(config_path):
    raise FileNotFoundError(
        "config file not found or unreadable: {}".format(config_path)
    )

# --- [LLMParser]: host/port of the parser service -------------------------
llm_parser_section_name = "LLMParser"
LLMPARSER_HOST = config.get(llm_parser_section_name, 'LLMPARSER_HOST')
LLMPARSER_PORT = config.getint(llm_parser_section_name, 'LLMPARSER_PORT')

# --- [ChromaDB]: persistence directory and collection names ---------------
chroma_db_section_name = "ChromaDB"
CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME')
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = config.getint(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM')
# getboolean() accepts true/false, yes/no, on/off and 1/0 (case-insensitive)
# without evaluating the config text as Python code.
TEXT2DSL_IS_SHORTCUT = config.getboolean(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT')

CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)

# --- [Text2Vec]: embedding model name --------------------------------------
text2vec_section_name = "Text2Vec"
HF_TEXT2VEC_MODEL_NAME = config.get(text2vec_section_name, 'HF_TEXT2VEC_MODEL_NAME')

# --- [LLMProvider]: provider name ------------------------------------------
llm_provider_section_name = "LLMProvider"
LLM_PROVIDER_NAME = config.get(llm_provider_section_name, 'LLM_PROVIDER_NAME')

# --- [LLMModel]: free-form options collected into llm_config_dict ----------
# ConfigParser lower-cases option names by default, so "MODEL_NAME" in the INI
# file becomes the key "model_name" here.
llm_model_section_name = "LLMModel"
llm_config_dict = {}
for option in config.options(llm_model_section_name):
    llm_config_dict[option] = type_convert(config.get(llm_model_section_name, option))
if __name__ == "__main__":
    # Debug dump of the resolved settings when this module is run directly.
    # Each value is formatted into the message up front: a message without a
    # "{}" placeholder would drop the extra positional argument (assuming the
    # loguru-style logger used elsewhere in this project), so the values
    # would never appear in the log.
    logger.info("PROJECT_DIR_PATH: {}".format(PROJECT_DIR_PATH))
    logger.info("EMB_MODEL_PATH: {}".format(HF_TEXT2VEC_MODEL_NAME))
    logger.info("CHROMA_DB_PERSIST_PATH: {}".format(CHROMA_DB_PERSIST_PATH))
    logger.info("LLMPARSER_HOST: {}".format(LLMPARSER_HOST))
    logger.info("LLMPARSER_PORT: {}".format(LLMPARSER_PORT))
    logger.info("llm_config_dict: {}".format(llm_config_dict))
    logger.info("is_shortcut: {}".format(TEXT2DSL_IS_SHORTCUT))

View File

@@ -0,0 +1,24 @@
[LLMParser]
LLMPARSER_HOST = 127.0.0.1
LLMPARSER_PORT = 9092
[ChromaDB]
CHROMA_DB_PERSIST_DIR = chm_db
PRESET_QUERY_COLLECTION_NAME = preset_query_collection
SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
TEXT2DSL_COLLECTION_NAME = text2dsl_collection
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
TEXT2DSL_IS_SHORTCUT = False
[Text2Vec]
HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese
[LLMProvider]
LLM_PROVIDER_NAME = openai
[LLMModel]
MODEL_NAME = gpt-3.5-turbo-16k
OPENAI_API_KEY = YOUR_API_KEY
TEMPERATURE = 0.0

View File

@@ -1,31 +0,0 @@
# -*- coding:utf-8 -*-
import os
# Directory that contains this module; relative paths below are anchored here.
_MODULE_FILE = os.path.abspath(__file__)
PROJECT_DIR_PATH = os.path.dirname(_MODULE_FILE)

# Address the LLM parser service is reached at.
LLMPARSER_HOST = "127.0.0.1"
LLMPARSER_PORT = 9092

# LLM model settings.
MODEL_NAME = "gpt-3.5-turbo-16k"
OPENAI_API_KEY = "YOUR_API_KEY"
OPENAI_API_BASE = ""
TEMPERATURE = 0.0

# Chroma persistence directory (relative to PROJECT_DIR_PATH) and its
# resolved absolute path.
CHROMA_DB_PERSIST_DIR = "chm_db"
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)

# Chroma collection names.
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
SOLVED_QUERY_COLLECTION_NAME = "solved_query_collection"
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"

# Text2DSL settings.
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
TEXT2DSL_IS_SHORTCUT = False

# Hugging Face text2vec model name.
HF_TEXT2VEC_MODEL_NAME = "GanymedeNil/text2vec-large-chinese"
if __name__ == "__main__":
    # Debug dump of the settings when this module is run directly.
    # The module never imports a logger at file level, so running it as a
    # script would raise NameError; import it here, where it is needed.
    from loguru import logger

    logger.info("PROJECT_DIR_PATH: {}", PROJECT_DIR_PATH)
    logger.info("EMB_MODEL_PATH: {}", HF_TEXT2VEC_MODEL_NAME)
    logger.info("CHROMA_DB_PERSIST_PATH: {}", CHROMA_DB_PERSIST_PATH)
    logger.info("LLMPARSER_HOST: {}", LLMPARSER_HOST)
    logger.info("LLMPARSER_PORT: {}", LLMPARSER_PORT)

View File

@@ -5,7 +5,7 @@ import re
import sys import sys
from typing import Any, List, Mapping, Union from typing import Any, List, Mapping, Union
from loguru import logger from util.logging_utils import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))

View File

@@ -7,6 +7,7 @@ from typing import List
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from chromadb.api import Collection from chromadb.api import Collection
from preset_query_db import ( from preset_query_db import (
@@ -18,7 +19,7 @@ from preset_query_db import (
from util.text2vec import Text2VecEmbeddingFunction from util.text2vec import Text2VecEmbeddingFunction
from run_config import PRESET_QUERY_COLLECTION_NAME from config.config_parse import PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import client from util.chromadb_instance import client

View File

@@ -5,18 +5,18 @@ import sys
import uuid import uuid
from typing import Any, List, Mapping, Optional, Union from typing import Any, List, Mapping, Optional, Union
from loguru import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings from chromadb.api import Collection, Documents, Embeddings
from util.text2vec import Text2VecEmbeddingFunction from util.text2vec import Text2VecEmbeddingFunction
from run_config import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import (client, from util.chromadb_instance import (client,
get_chroma_collection_size, query_chroma_collection, get_chroma_collection_size, query_chroma_collection,
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format, parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,

View File

@@ -3,18 +3,18 @@ import os
import sys import sys
from typing import List, Mapping from typing import List, Mapping
from loguru import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from langchain.vectorstores import Chroma from langchain.vectorstores import Chroma
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
from few_shot_example.sql_exampler import examplars as sql_examplars from few_shot_example.sql_exampler import examplars as sql_examplars
from util.text2vec import hg_embedding from util.text2vec import hg_embedding
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2 from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
def reload_sql_example_collection( def reload_sql_example_collection(

View File

@@ -6,12 +6,14 @@ from typing import List, Mapping
import requests import requests
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT,
LLMPARSER_HOST, LLMPARSER_PORT)
from few_shot_example.sql_exampler import examplars as sql_examplars from few_shot_example.sql_exampler import examplars as sql_examplars
from run_config import LLMPARSER_HOST, LLMPARSER_PORT from util.logging_utils import logger
def text2dsl_setting_update( def text2dsl_setting_update(

View File

@@ -1,6 +1,15 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import re import re
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
def schema_link_parse(schema_link_output): def schema_link_parse(schema_link_output):
try: try:

View File

@@ -6,6 +6,8 @@ from typing import List, Mapping
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector from langchain.prompts.example_selector import SemanticSimilarityExampleSelector

View File

@@ -2,11 +2,11 @@ import os
import sys import sys
from typing import List, Union, Mapping from typing import List, Union, Mapping
from loguru import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from sql.prompt_maker import ( from sql.prompt_maker import (
schema_linking_exampler, schema_linking_exampler,
sql_exampler, sql_exampler,
@@ -24,7 +24,7 @@ from sql.output_parser import (
) )
from util.llm_instance import llm from util.llm_instance import llm
from run_config import TEXT2DSL_IS_SHORTCUT from config.config_parse import TEXT2DSL_IS_SHORTCUT
class Text2DSLAgent(object): class Text2DSLAgent(object):
@@ -65,10 +65,10 @@ class Text2DSLAgent(object):
linking: Union[List[Mapping[str, str]], None] = None, linking: Union[List[Mapping[str, str]], None] = None,
): ):
logger.info("query_text: {}", query_text) logger.info("query_text: {}".format(query_text))
logger.info("schema: {}", schema) logger.info("schema: {}".format(schema))
logger.info("current_date: {}", current_date) logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}", linking) logger.info("prior_schema_links: {}".format(linking))
if linking is not None: if linking is not None:
prior_schema_links = { prior_schema_links = {
@@ -87,7 +87,7 @@ class Text2DSLAgent(object):
prior_schema_links, prior_schema_links,
self.sql_example_selector, self.sql_example_selector,
) )
logger.info("schema_linking_prompt-> {}", schema_linking_prompt) logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt))
schema_link_output = self.llm(schema_linking_prompt) schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output) schema_link_str = self.schema_link_parse(schema_link_output)
@@ -98,7 +98,7 @@ class Text2DSLAgent(object):
current_date, current_date,
self.sql_example_selector, self.sql_example_selector,
) )
logger.info("sql_prompt->", sql_prompt) logger.info("sql_prompt-> {}".format(sql_prompt))
sql_output = self.llm(sql_prompt) sql_output = self.llm(sql_prompt)
resp = dict() resp = dict()
@@ -113,7 +113,7 @@ class Text2DSLAgent(object):
resp["sqlOutput"] = sql_output resp["sqlOutput"] = sql_output
logger.info("resp: ", resp) logger.info("resp: {}".format(resp))
return resp return resp
@@ -125,10 +125,10 @@ class Text2DSLAgent(object):
linking: Union[List[Mapping[str, str]], None] = None, linking: Union[List[Mapping[str, str]], None] = None,
): ):
logger.info("query_text: ", query_text) logger.info("query_text: {}".format(query_text))
logger.info("schema: ", schema) logger.info("schema: {}".format(schema))
logger.info("current_date: ", current_date) logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: ", linking) logger.info("prior_schema_links: {}".format(linking))
if linking is not None: if linking is not None:
prior_schema_links = { prior_schema_links = {
@@ -148,7 +148,7 @@ class Text2DSLAgent(object):
prior_schema_links, prior_schema_links,
self.sql_example_selector, self.sql_example_selector,
) )
logger.info("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt) logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt))
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt) schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
schema_linking_str = self.combo_schema_link_parse( schema_linking_str = self.combo_schema_link_parse(
@@ -167,7 +167,7 @@ class Text2DSLAgent(object):
resp["schemaLinkStr"] = schema_linking_str resp["schemaLinkStr"] = schema_linking_str
resp["sqlOutput"] = sql_str resp["sqlOutput"] = sql_str
logger.info("resp: ", resp) logger.info("resp: {}".format(resp))
return resp return resp

View File

@@ -4,10 +4,6 @@ import sys
import uvicorn import uvicorn
from util.logging_utils import init_logger
init_logger()
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
@@ -15,7 +11,7 @@ from typing import Any, List, Mapping
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from run_config import LLMPARSER_HOST, LLMPARSER_PORT from config.config_parse import LLMPARSER_HOST, LLMPARSER_PORT
from services_router import (query2sql_service, preset_query_service, from services_router import (query2sql_service, preset_query_service,
solved_query_service, plugin_call_service) solved_query_service, plugin_call_service)

View File

@@ -4,9 +4,15 @@ from typing import Any, List, Mapping, Optional, Union
import chromadb import chromadb
from chromadb.api import Collection from chromadb.api import Collection
from chromadb.config import Settings from chromadb.config import Settings
from loguru import logger
from run_config import CHROMA_DB_PERSIST_PATH import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from config.config_parse import CHROMA_DB_PERSIST_PATH
client = chromadb.Client( client = chromadb.Client(
Settings( Settings(
@@ -75,7 +81,7 @@ def query_chroma_collection(collection:Collection, query_texts:List[str],
else: else:
outer_filter = None outer_filter = None
logger.info('outer_filter: ', outer_filter) print('outer_filter: ', outer_filter)
res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter) res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter)
return res return res

View File

@@ -1,12 +1,21 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from langchain.llms import OpenAI from langchain import llms
from run_config import * import os
from util.stringutils import * import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
llm = OpenAI( from config.config_parse import LLM_PROVIDER_NAME, llm_config_dict
model_name=MODEL_NAME,
openai_api_key=OPENAI_API_KEY,
openai_api_base=default_if_blank(OPENAI_API_BASE), def get_llm_provider(llm_provider_name: str, llm_config_dict: dict):
temperature=TEMPERATURE, if llm_provider_name in llms.type_to_cls_dict:
) llm_provider = llms.type_to_cls_dict[llm_provider_name]
llm = llm_provider(**llm_config_dict)
return llm
else:
raise Exception("llm_provider_name is not supported: {}".format(llm_provider_name))
llm = get_llm_provider(LLM_PROVIDER_NAME, llm_config_dict)

View File

@@ -1,6 +1,4 @@
from loguru import logger from loguru import logger
import sys
# logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
def init_logger():
logger.remove()
logger.add("llmparser.info.log")

View File

@@ -1,10 +0,0 @@
def is_blank(s: str) -> bool:
    """Return True when *s* is None, empty, or contains only whitespace."""
    if not s:
        # None and "" are blank without needing to strip anything.
        return True
    return not s.strip()
def is_not_blank(s: str) -> bool:
    """Return True when *s* has at least one non-whitespace character."""
    # Same truth table as negating the blank check: falsy input (None, "")
    # and whitespace-only strings yield False.
    return bool(s and s.strip())
def default_if_blank(s: str, default: str = None) -> str:
    """Return *s* unless it is blank, in which case return *default*.

    Args:
        s: Candidate string; may be None.
        default: Value returned when *s* is None, empty, or whitespace-only.

    Returns:
        *s* unchanged (not stripped) when it is not blank, otherwise *default*.
    """
    if is_not_blank(s):
        return s
    return default

View File

@@ -4,7 +4,7 @@ from typing import List
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from run_config import HF_TEXT2VEC_MODEL_NAME from config.config_parse import HF_TEXT2VEC_MODEL_NAME
hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME) hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)