mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Features that allow user to config LLM by config file. (#174)
This commit is contained in:
53
chat/core/src/main/python/config/config_parse.py
Normal file
53
chat/core/src/main/python/config/config_parse.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import configparser
|
||||
from util.logging_utils import logger
|
||||
|
||||
|
||||
def type_convert(input_str: str):
|
||||
try:
|
||||
return eval(input_str)
|
||||
except:
|
||||
return input_str
|
||||
|
||||
|
||||
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
config_file = "run_config.ini"
|
||||
config_path = os.path.join(PROJECT_DIR_PATH, config_file)
|
||||
config = configparser.ConfigParser()
|
||||
config.read(config_path)
|
||||
|
||||
llm_parser_section_name = "LLMParser"
|
||||
LLMPARSER_HOST = config.get(llm_parser_section_name, 'LLMPARSER_HOST')
|
||||
LLMPARSER_PORT = int(config.get(llm_parser_section_name, 'LLMPARSER_PORT'))
|
||||
|
||||
chroma_db_section_name = "ChromaDB"
|
||||
CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
|
||||
PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
|
||||
SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
|
||||
TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME')
|
||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM'))
|
||||
TEXT2DSL_IS_SHORTCUT = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT'))
|
||||
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
||||
|
||||
text2vec_section_name = "Text2Vec"
|
||||
HF_TEXT2VEC_MODEL_NAME = config.get(text2vec_section_name, 'HF_TEXT2VEC_MODEL_NAME')
|
||||
|
||||
llm_provider_section_name = "LLMProvider"
|
||||
LLM_PROVIDER_NAME = config.get(llm_provider_section_name, 'LLM_PROVIDER_NAME')
|
||||
|
||||
llm_model_section_name = "LLMModel"
|
||||
llm_config_dict = {}
|
||||
for option in config.options(llm_model_section_name):
|
||||
llm_config_dict[option] = type_convert(config.get(llm_model_section_name, option))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH)
|
||||
logger.info("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME)
|
||||
logger.info("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH)
|
||||
logger.info("LLMPARSER_HOST: ", LLMPARSER_HOST)
|
||||
logger.info("LLMPARSER_PORT: ", LLMPARSER_PORT)
|
||||
logger.info("llm_config_dict: ", llm_config_dict)
|
||||
logger.info("is_shortcut: ", TEXT2DSL_IS_SHORTCUT)
|
||||
24
chat/core/src/main/python/config/run_config.ini
Normal file
24
chat/core/src/main/python/config/run_config.ini
Normal file
@@ -0,0 +1,24 @@
|
||||
[LLMParser]
|
||||
LLMPARSER_HOST = 127.0.0.1
|
||||
LLMPARSER_PORT = 9092
|
||||
|
||||
[ChromaDB]
|
||||
CHROMA_DB_PERSIST_DIR = chm_db
|
||||
PRESET_QUERY_COLLECTION_NAME = preset_query_collection
|
||||
SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
|
||||
TEXT2DSL_COLLECTION_NAME = text2dsl_collection
|
||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
||||
TEXT2DSL_IS_SHORTCUT = False
|
||||
|
||||
[Text2Vec]
|
||||
HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese
|
||||
|
||||
|
||||
[LLMProvider]
|
||||
LLM_PROVIDER_NAME = openai
|
||||
|
||||
|
||||
[LLMModel]
|
||||
MODEL_NAME = gpt-3.5-turbo-16k
|
||||
OPENAI_API_KEY = YOUR_API_KEY
|
||||
TEMPERATURE = 0.0
|
||||
@@ -1,31 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
LLMPARSER_HOST = "127.0.0.1"
|
||||
LLMPARSER_PORT = 9092
|
||||
|
||||
MODEL_NAME = "gpt-3.5-turbo-16k"
|
||||
OPENAI_API_KEY = "YOUR_API_KEY"
|
||||
OPENAI_API_BASE = ""
|
||||
|
||||
TEMPERATURE = 0.0
|
||||
|
||||
CHROMA_DB_PERSIST_DIR = "chm_db"
|
||||
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
|
||||
SOLVED_QUERY_COLLECTION_NAME = "solved_query_collection"
|
||||
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
|
||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
||||
TEXT2DSL_IS_SHORTCUT = False
|
||||
|
||||
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
||||
|
||||
HF_TEXT2VEC_MODEL_NAME = "GanymedeNil/text2vec-large-chinese"
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("PROJECT_DIR_PATH: {}", PROJECT_DIR_PATH)
|
||||
logger.info("EMB_MODEL_PATH: {}", HF_TEXT2VEC_MODEL_NAME)
|
||||
logger.info("CHROMA_DB_PERSIST_PATH: {}", CHROMA_DB_PERSIST_PATH)
|
||||
logger.info("LLMPARSER_HOST: {}", LLMPARSER_HOST)
|
||||
logger.info("LLMPARSER_PORT: {}", LLMPARSER_PORT)
|
||||
@@ -5,7 +5,7 @@ import re
|
||||
import sys
|
||||
from typing import Any, List, Mapping, Union
|
||||
|
||||
from loguru import logger
|
||||
from util.logging_utils import logger
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import List
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
from chromadb.api import Collection
|
||||
|
||||
from preset_query_db import (
|
||||
@@ -18,7 +19,7 @@ from preset_query_db import (
|
||||
|
||||
from util.text2vec import Text2VecEmbeddingFunction
|
||||
|
||||
from run_config import PRESET_QUERY_COLLECTION_NAME
|
||||
from config.config_parse import PRESET_QUERY_COLLECTION_NAME
|
||||
from util.chromadb_instance import client
|
||||
|
||||
|
||||
|
||||
@@ -5,18 +5,18 @@ import sys
|
||||
import uuid
|
||||
from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
|
||||
from util.text2vec import Text2VecEmbeddingFunction
|
||||
|
||||
from run_config import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
|
||||
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
|
||||
from util.chromadb_instance import (client,
|
||||
get_chroma_collection_size, query_chroma_collection,
|
||||
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
|
||||
|
||||
@@ -3,18 +3,18 @@ import os
|
||||
import sys
|
||||
from typing import List, Mapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
||||
|
||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||
from util.text2vec import hg_embedding
|
||||
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
|
||||
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||
from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||
|
||||
|
||||
def reload_sql_example_collection(
|
||||
|
||||
@@ -6,12 +6,14 @@ from typing import List, Mapping
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT
|
||||
from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT,
|
||||
LLMPARSER_HOST, LLMPARSER_PORT)
|
||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
|
||||
from util.logging_utils import logger
|
||||
|
||||
|
||||
def text2dsl_setting_update(
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import re
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
|
||||
def schema_link_parse(schema_link_output):
|
||||
try:
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import List, Mapping
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
||||
|
||||
@@ -2,11 +2,11 @@ import os
|
||||
import sys
|
||||
from typing import List, Union, Mapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
from sql.prompt_maker import (
|
||||
schema_linking_exampler,
|
||||
sql_exampler,
|
||||
@@ -24,7 +24,7 @@ from sql.output_parser import (
|
||||
)
|
||||
|
||||
from util.llm_instance import llm
|
||||
from run_config import TEXT2DSL_IS_SHORTCUT
|
||||
from config.config_parse import TEXT2DSL_IS_SHORTCUT
|
||||
|
||||
|
||||
class Text2DSLAgent(object):
|
||||
@@ -65,10 +65,10 @@ class Text2DSLAgent(object):
|
||||
linking: Union[List[Mapping[str, str]], None] = None,
|
||||
):
|
||||
|
||||
logger.info("query_text: {}", query_text)
|
||||
logger.info("schema: {}", schema)
|
||||
logger.info("current_date: {}", current_date)
|
||||
logger.info("prior_schema_links: {}", linking)
|
||||
logger.info("query_text: {}".format(query_text))
|
||||
logger.info("schema: {}".format(schema))
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(linking))
|
||||
|
||||
if linking is not None:
|
||||
prior_schema_links = {
|
||||
@@ -87,7 +87,7 @@ class Text2DSLAgent(object):
|
||||
prior_schema_links,
|
||||
self.sql_example_selector,
|
||||
)
|
||||
logger.info("schema_linking_prompt-> {}", schema_linking_prompt)
|
||||
logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt))
|
||||
schema_link_output = self.llm(schema_linking_prompt)
|
||||
schema_link_str = self.schema_link_parse(schema_link_output)
|
||||
|
||||
@@ -98,7 +98,7 @@ class Text2DSLAgent(object):
|
||||
current_date,
|
||||
self.sql_example_selector,
|
||||
)
|
||||
logger.info("sql_prompt->", sql_prompt)
|
||||
logger.info("sql_prompt-> {}".format(sql_prompt))
|
||||
sql_output = self.llm(sql_prompt)
|
||||
|
||||
resp = dict()
|
||||
@@ -113,7 +113,7 @@ class Text2DSLAgent(object):
|
||||
|
||||
resp["sqlOutput"] = sql_output
|
||||
|
||||
logger.info("resp: ", resp)
|
||||
logger.info("resp: {}".format(resp))
|
||||
|
||||
return resp
|
||||
|
||||
@@ -125,10 +125,10 @@ class Text2DSLAgent(object):
|
||||
linking: Union[List[Mapping[str, str]], None] = None,
|
||||
):
|
||||
|
||||
logger.info("query_text: ", query_text)
|
||||
logger.info("schema: ", schema)
|
||||
logger.info("current_date: ", current_date)
|
||||
logger.info("prior_schema_links: ", linking)
|
||||
logger.info("query_text: {}".format(query_text))
|
||||
logger.info("schema: {}".format(schema))
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(linking))
|
||||
|
||||
if linking is not None:
|
||||
prior_schema_links = {
|
||||
@@ -148,7 +148,7 @@ class Text2DSLAgent(object):
|
||||
prior_schema_links,
|
||||
self.sql_example_selector,
|
||||
)
|
||||
logger.info("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
|
||||
logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt))
|
||||
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
|
||||
|
||||
schema_linking_str = self.combo_schema_link_parse(
|
||||
@@ -167,7 +167,7 @@ class Text2DSLAgent(object):
|
||||
resp["schemaLinkStr"] = schema_linking_str
|
||||
resp["sqlOutput"] = sql_str
|
||||
|
||||
logger.info("resp: ", resp)
|
||||
logger.info("resp: {}".format(resp))
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@@ -4,10 +4,6 @@ import sys
|
||||
|
||||
import uvicorn
|
||||
|
||||
from util.logging_utils import init_logger
|
||||
|
||||
init_logger()
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
@@ -15,7 +11,7 @@ from typing import Any, List, Mapping
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
|
||||
from config.config_parse import LLMPARSER_HOST, LLMPARSER_PORT
|
||||
|
||||
from services_router import (query2sql_service, preset_query_service,
|
||||
solved_query_service, plugin_call_service)
|
||||
|
||||
@@ -4,9 +4,15 @@ from typing import Any, List, Mapping, Optional, Union
|
||||
import chromadb
|
||||
from chromadb.api import Collection
|
||||
from chromadb.config import Settings
|
||||
from loguru import logger
|
||||
|
||||
from run_config import CHROMA_DB_PERSIST_PATH
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
from config.config_parse import CHROMA_DB_PERSIST_PATH
|
||||
|
||||
client = chromadb.Client(
|
||||
Settings(
|
||||
@@ -75,7 +81,7 @@ def query_chroma_collection(collection:Collection, query_texts:List[str],
|
||||
else:
|
||||
outer_filter = None
|
||||
|
||||
logger.info('outer_filter: ', outer_filter)
|
||||
print('outer_filter: ', outer_filter)
|
||||
res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter)
|
||||
return res
|
||||
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from langchain.llms import OpenAI
|
||||
from langchain import llms
|
||||
|
||||
from run_config import *
|
||||
from util.stringutils import *
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
llm = OpenAI(
|
||||
model_name=MODEL_NAME,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
openai_api_base=default_if_blank(OPENAI_API_BASE),
|
||||
temperature=TEMPERATURE,
|
||||
)
|
||||
from config.config_parse import LLM_PROVIDER_NAME, llm_config_dict
|
||||
|
||||
|
||||
def get_llm_provider(llm_provider_name: str, llm_config_dict: dict):
|
||||
if llm_provider_name in llms.type_to_cls_dict:
|
||||
llm_provider = llms.type_to_cls_dict[llm_provider_name]
|
||||
llm = llm_provider(**llm_config_dict)
|
||||
return llm
|
||||
else:
|
||||
raise Exception("llm_provider_name is not supported: {}".format(llm_provider_name))
|
||||
|
||||
|
||||
llm = get_llm_provider(LLM_PROVIDER_NAME, llm_config_dict)
|
||||
@@ -1,6 +1,4 @@
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
|
||||
def init_logger():
|
||||
logger.remove()
|
||||
logger.add("llmparser.info.log")
|
||||
# logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
|
||||
@@ -1,10 +0,0 @@
|
||||
def is_blank(s: str) -> bool:
|
||||
return not (s and s.strip())
|
||||
|
||||
|
||||
def is_not_blank(s: str) -> bool:
|
||||
return not is_blank(s)
|
||||
|
||||
|
||||
def default_if_blank(s: str, default: str = None) -> str:
|
||||
return s if is_not_blank(s) else default
|
||||
@@ -4,7 +4,7 @@ from typing import List
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from run_config import HF_TEXT2VEC_MODEL_NAME
|
||||
from config.config_parse import HF_TEXT2VEC_MODEL_NAME
|
||||
|
||||
hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user