diff --git a/chat/core/src/main/python/config/config_parse.py b/chat/core/src/main/python/config/config_parse.py
new file mode 100644
index 000000000..f8425347d
--- /dev/null
+++ b/chat/core/src/main/python/config/config_parse.py
@@ -0,0 +1,65 @@
+# -*- coding:utf-8 -*-
+import ast
+import os
+import configparser
+from util.logging_utils import logger
+
+
+def type_convert(input_str: str):
+    """Convert a raw config string into a Python literal when possible.
+
+    Values such as ``0.0``, ``15`` or ``False`` become float/int/bool. Anything
+    that is not a valid Python literal (a model name, an API key, ...) is
+    returned unchanged as a string.
+    """
+    try:
+        # literal_eval only parses literals, so a config value can never run
+        # arbitrary code the way eval() would.
+        return ast.literal_eval(input_str)
+    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+        return input_str
+
+
+PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
+
+# run_config.ini is looked up in the same directory as this module.
+config_file = "run_config.ini"
+config_path = os.path.join(PROJECT_DIR_PATH, config_file)
+config = configparser.ConfigParser()
+config.read(config_path)
+
+llm_parser_section_name = "LLMParser"
+LLMPARSER_HOST = config.get(llm_parser_section_name, 'LLMPARSER_HOST')
+LLMPARSER_PORT = config.getint(llm_parser_section_name, 'LLMPARSER_PORT')
+
+chroma_db_section_name = "ChromaDB"
+CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
+PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
+SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
+TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME')
+TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = config.getint(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM')
+# getboolean parses the flag as a boolean instead of evaluating it as code.
+TEXT2DSL_IS_SHORTCUT = config.getboolean(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT')
+CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
+
+text2vec_section_name = "Text2Vec"
+HF_TEXT2VEC_MODEL_NAME = config.get(text2vec_section_name, 'HF_TEXT2VEC_MODEL_NAME')
+
+llm_provider_section_name = "LLMProvider"
+LLM_PROVIDER_NAME = config.get(llm_provider_section_name, 'LLM_PROVIDER_NAME')
+
+llm_model_section_name = "LLMModel"
+# Every option of the [LLMModel] section is collected, converted by type_convert.
+llm_config_dict = {}
+for option in config.options(llm_model_section_name):
+    llm_config_dict[option] = type_convert(config.get(llm_model_section_name, option))
+
+
+if __name__ == "__main__":  # manual check: log the values parsed from run_config.ini
+    logger.info("PROJECT_DIR_PATH: {}".format(PROJECT_DIR_PATH))
+    logger.info("EMB_MODEL_PATH: {}".format(HF_TEXT2VEC_MODEL_NAME))
+    logger.info("CHROMA_DB_PERSIST_PATH: {}".format(CHROMA_DB_PERSIST_PATH))
+    logger.info("LLMPARSER_HOST: {}".format(LLMPARSER_HOST))
+    logger.info("LLMPARSER_PORT: {}".format(LLMPARSER_PORT))
+    logger.info("llm_config_dict: {}".format(llm_config_dict))
+    logger.info("is_shortcut: {}".format(TEXT2DSL_IS_SHORTCUT))
diff --git a/chat/core/src/main/python/config/run_config.ini b/chat/core/src/main/python/config/run_config.ini
new file mode 100644
index 000000000..35b0b5375
--- /dev/null
+++ b/chat/core/src/main/python/config/run_config.ini
@@ -0,0 +1,24 @@
+[LLMParser]
+LLMPARSER_HOST = 127.0.0.1
+LLMPARSER_PORT = 9092
+
+[ChromaDB]
+CHROMA_DB_PERSIST_DIR = chm_db
+PRESET_QUERY_COLLECTION_NAME = preset_query_collection
+SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
+TEXT2DSL_COLLECTION_NAME = text2dsl_collection
+TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
+TEXT2DSL_IS_SHORTCUT = False
+
+[Text2Vec]
+HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese
+
+
+[LLMProvider]
+LLM_PROVIDER_NAME = openai
+
+
+[LLMModel]
+MODEL_NAME = gpt-3.5-turbo-16k
+OPENAI_API_KEY = YOUR_API_KEY
+TEMPERATURE = 0.0
diff --git a/chat/core/src/main/python/run_config.py b/chat/core/src/main/python/run_config.py
deleted file mode 100644
index 8be2a98c7..000000000
--- a/chat/core/src/main/python/run_config.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# -*- coding:utf-8 -*-
-import os
-
-PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
-
-LLMPARSER_HOST = "127.0.0.1"
-LLMPARSER_PORT = 9092
-
-MODEL_NAME = "gpt-3.5-turbo-16k"
-OPENAI_API_KEY = "YOUR_API_KEY"
-OPENAI_API_BASE = ""
-
-TEMPERATURE = 0.0
-
-CHROMA_DB_PERSIST_DIR = "chm_db"
-PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
-SOLVED_QUERY_COLLECTION_NAME = "solved_query_collection"
-TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
-TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
-TEXT2DSL_IS_SHORTCUT = False
-
-CHROMA_DB_PERSIST_PATH =
os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR) - -HF_TEXT2VEC_MODEL_NAME = "GanymedeNil/text2vec-large-chinese" - -if __name__ == "__main__": - logger.info("PROJECT_DIR_PATH: {}", PROJECT_DIR_PATH) - logger.info("EMB_MODEL_PATH: {}", HF_TEXT2VEC_MODEL_NAME) - logger.info("CHROMA_DB_PERSIST_PATH: {}", CHROMA_DB_PERSIST_PATH) - logger.info("LLMPARSER_HOST: {}", LLMPARSER_HOST) - logger.info("LLMPARSER_PORT: {}", LLMPARSER_PORT) diff --git a/chat/core/src/main/python/services/plugin_call/prompt_construct.py b/chat/core/src/main/python/services/plugin_call/prompt_construct.py index a220acf5a..64b243578 100644 --- a/chat/core/src/main/python/services/plugin_call/prompt_construct.py +++ b/chat/core/src/main/python/services/plugin_call/prompt_construct.py @@ -5,7 +5,7 @@ import re import sys from typing import Any, List, Mapping, Union -from loguru import logger +from util.logging_utils import logger sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) diff --git a/chat/core/src/main/python/services/preset_retrieval/run.py b/chat/core/src/main/python/services/preset_retrieval/run.py index a2df902b6..5c3e9adae 100644 --- a/chat/core/src/main/python/services/preset_retrieval/run.py +++ b/chat/core/src/main/python/services/preset_retrieval/run.py @@ -7,6 +7,7 @@ from typing import List sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from util.logging_utils import logger from chromadb.api import Collection from preset_query_db import ( @@ -18,7 +19,7 @@ from preset_query_db import ( from util.text2vec import Text2VecEmbeddingFunction -from run_config import PRESET_QUERY_COLLECTION_NAME +from config.config_parse import PRESET_QUERY_COLLECTION_NAME from util.chromadb_instance import client diff --git a/chat/core/src/main/python/services/query_retrieval/run.py 
b/chat/core/src/main/python/services/query_retrieval/run.py index 93f135a0b..010130232 100644 --- a/chat/core/src/main/python/services/query_retrieval/run.py +++ b/chat/core/src/main/python/services/query_retrieval/run.py @@ -5,18 +5,18 @@ import sys import uuid from typing import Any, List, Mapping, Optional, Union -from loguru import logger - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from util.logging_utils import logger + import chromadb from chromadb.config import Settings from chromadb.api import Collection, Documents, Embeddings from util.text2vec import Text2VecEmbeddingFunction -from run_config import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME +from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME from util.chromadb_instance import (client, get_chroma_collection_size, query_chroma_collection, parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format, diff --git a/chat/core/src/main/python/services/sql/constructor.py b/chat/core/src/main/python/services/sql/constructor.py index aed344835..e313e765f 100644 --- a/chat/core/src/main/python/services/sql/constructor.py +++ b/chat/core/src/main/python/services/sql/constructor.py @@ -3,18 +3,18 @@ import os import sys from typing import List, Mapping -from loguru import logger - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from util.logging_utils import logger + from langchain.vectorstores import Chroma from langchain.prompts.example_selector import SemanticSimilarityExampleSelector from few_shot_example.sql_exampler import examplars as sql_examplars from util.text2vec import hg_embedding from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2 -from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 
+from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM def reload_sql_example_collection( diff --git a/chat/core/src/main/python/services/sql/examples_reload_run.py b/chat/core/src/main/python/services/sql/examples_reload_run.py index ea374932a..cf5634c8c 100644 --- a/chat/core/src/main/python/services/sql/examples_reload_run.py +++ b/chat/core/src/main/python/services/sql/examples_reload_run.py @@ -6,12 +6,14 @@ from typing import List, Mapping import requests +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT +from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT, + LLMPARSER_HOST, LLMPARSER_PORT) from few_shot_example.sql_exampler import examplars as sql_examplars -from run_config import LLMPARSER_HOST, LLMPARSER_PORT +from util.logging_utils import logger def text2dsl_setting_update( diff --git a/chat/core/src/main/python/services/sql/output_parser.py b/chat/core/src/main/python/services/sql/output_parser.py index 210c2e1f3..ee07132eb 100644 --- a/chat/core/src/main/python/services/sql/output_parser.py +++ b/chat/core/src/main/python/services/sql/output_parser.py @@ -1,6 +1,15 @@ # -*- coding:utf-8 -*- import re +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from util.logging_utils import logger + def schema_link_parse(schema_link_output): try: diff --git a/chat/core/src/main/python/services/sql/prompt_maker.py b/chat/core/src/main/python/services/sql/prompt_maker.py index 892ab7aa4..3ed8d84ad 100644 --- 
a/chat/core/src/main/python/services/sql/prompt_maker.py +++ b/chat/core/src/main/python/services/sql/prompt_maker.py @@ -6,6 +6,8 @@ from typing import List, Mapping sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from util.logging_utils import logger + from langchain.prompts import PromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.example_selector import SemanticSimilarityExampleSelector diff --git a/chat/core/src/main/python/services/sql/run.py b/chat/core/src/main/python/services/sql/run.py index 2dbad25bb..623c8cfdd 100644 --- a/chat/core/src/main/python/services/sql/run.py +++ b/chat/core/src/main/python/services/sql/run.py @@ -2,11 +2,11 @@ import os import sys from typing import List, Union, Mapping -from loguru import logger - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from util.logging_utils import logger + from sql.prompt_maker import ( schema_linking_exampler, sql_exampler, @@ -24,7 +24,7 @@ from sql.output_parser import ( ) from util.llm_instance import llm -from run_config import TEXT2DSL_IS_SHORTCUT +from config.config_parse import TEXT2DSL_IS_SHORTCUT class Text2DSLAgent(object): @@ -65,10 +65,10 @@ class Text2DSLAgent(object): linking: Union[List[Mapping[str, str]], None] = None, ): - logger.info("query_text: {}", query_text) - logger.info("schema: {}", schema) - logger.info("current_date: {}", current_date) - logger.info("prior_schema_links: {}", linking) + logger.info("query_text: {}".format(query_text)) + logger.info("schema: {}".format(schema)) + logger.info("current_date: {}".format(current_date)) + logger.info("prior_schema_links: {}".format(linking)) if linking is not None: prior_schema_links = { @@ -87,7 +87,7 @@ class Text2DSLAgent(object): prior_schema_links, self.sql_example_selector, ) - 
logger.info("schema_linking_prompt-> {}", schema_linking_prompt) + logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt)) schema_link_output = self.llm(schema_linking_prompt) schema_link_str = self.schema_link_parse(schema_link_output) @@ -98,7 +98,7 @@ class Text2DSLAgent(object): current_date, self.sql_example_selector, ) - logger.info("sql_prompt->", sql_prompt) + logger.info("sql_prompt-> {}".format(sql_prompt)) sql_output = self.llm(sql_prompt) resp = dict() @@ -113,7 +113,7 @@ class Text2DSLAgent(object): resp["sqlOutput"] = sql_output - logger.info("resp: ", resp) + logger.info("resp: {}".format(resp)) return resp @@ -125,10 +125,10 @@ class Text2DSLAgent(object): linking: Union[List[Mapping[str, str]], None] = None, ): - logger.info("query_text: ", query_text) - logger.info("schema: ", schema) - logger.info("current_date: ", current_date) - logger.info("prior_schema_links: ", linking) + logger.info("query_text: {}".format(query_text)) + logger.info("schema: {}".format(schema)) + logger.info("current_date: {}".format(current_date)) + logger.info("prior_schema_links: {}".format(linking)) if linking is not None: prior_schema_links = { @@ -148,7 +148,7 @@ class Text2DSLAgent(object): prior_schema_links, self.sql_example_selector, ) - logger.info("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt) + logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt)) schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt) schema_linking_str = self.combo_schema_link_parse( @@ -167,7 +167,7 @@ class Text2DSLAgent(object): resp["schemaLinkStr"] = schema_linking_str resp["sqlOutput"] = sql_str - logger.info("resp: ", resp) + logger.info("resp: {}".format(resp)) return resp diff --git a/chat/core/src/main/python/supersonic_llmparser.py b/chat/core/src/main/python/supersonic_llmparser.py index db7c90a3c..2536d3b82 100644 --- a/chat/core/src/main/python/supersonic_llmparser.py +++ 
b/chat/core/src/main/python/supersonic_llmparser.py @@ -4,10 +4,6 @@ import sys import uvicorn -from util.logging_utils import init_logger - -init_logger() - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) @@ -15,7 +11,7 @@ from typing import Any, List, Mapping from fastapi import FastAPI, HTTPException -from run_config import LLMPARSER_HOST, LLMPARSER_PORT +from config.config_parse import LLMPARSER_HOST, LLMPARSER_PORT from services_router import (query2sql_service, preset_query_service, solved_query_service, plugin_call_service) diff --git a/chat/core/src/main/python/util/chromadb_instance.py b/chat/core/src/main/python/util/chromadb_instance.py index c6fcf1953..4c2f6824b 100644 --- a/chat/core/src/main/python/util/chromadb_instance.py +++ b/chat/core/src/main/python/util/chromadb_instance.py @@ -4,9 +4,15 @@ from typing import Any, List, Mapping, Optional, Union import chromadb from chromadb.api import Collection from chromadb.config import Settings -from loguru import logger -from run_config import CHROMA_DB_PERSIST_PATH +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from util.logging_utils import logger + +from config.config_parse import CHROMA_DB_PERSIST_PATH client = chromadb.Client( Settings( @@ -75,7 +81,7 @@ def query_chroma_collection(collection:Collection, query_texts:List[str], else: outer_filter = None - logger.info('outer_filter: ', outer_filter) + print('outer_filter: ', outer_filter) res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter) return res diff --git a/chat/core/src/main/python/util/llm_instance.py b/chat/core/src/main/python/util/llm_instance.py index 07cf8586f..7c8a468b0 100644 --- a/chat/core/src/main/python/util/llm_instance.py +++ b/chat/core/src/main/python/util/llm_instance.py @@ -1,12 +1,21 @@ # 
-*- coding:utf-8 -*-
-from langchain.llms import OpenAI
+from langchain import llms
 
-from run_config import *
-from util.stringutils import *
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
-llm = OpenAI(
-    model_name=MODEL_NAME,
-    openai_api_key=OPENAI_API_KEY,
-    openai_api_base=default_if_blank(OPENAI_API_BASE),
-    temperature=TEMPERATURE,
-)
+from config.config_parse import LLM_PROVIDER_NAME, llm_config_dict
+
+
+def get_llm_provider(llm_provider_name: str, llm_config_dict: dict):
+    """Return the ``llms.type_to_cls_dict`` entry for the name, built from the config."""
+    provider_cls = llms.type_to_cls_dict.get(llm_provider_name)
+    if provider_cls is None:
+        raise Exception("llm_provider_name is not supported: {}".format(llm_provider_name))
+    # The config entries are forwarded unchanged as constructor keyword arguments.
+    return provider_cls(**llm_config_dict)
+
+
+llm = get_llm_provider(LLM_PROVIDER_NAME, llm_config_dict)
\ No newline at end of file
diff --git a/chat/core/src/main/python/util/logging_utils.py b/chat/core/src/main/python/util/logging_utils.py
index 8928db9b6..359ab8e30 100644
--- a/chat/core/src/main/python/util/logging_utils.py
+++ b/chat/core/src/main/python/util/logging_utils.py
@@ -1,6 +1,4 @@
 from loguru import logger
+import sys
 
-
-def init_logger():
-    logger.remove()
-    logger.add("llmparser.info.log")
+# logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
\ No newline at end of file
diff --git a/chat/core/src/main/python/util/stringutils.py b/chat/core/src/main/python/util/stringutils.py
deleted file mode 100644
index 5b02b792d..000000000
--- a/chat/core/src/main/python/util/stringutils.py
+++ /dev/null
@@ -1,10 +0,0 @@
-def is_blank(s: str) -> bool:
-    return not (s and s.strip())
-
-
-def is_not_blank(s: str) -> bool:
-    return not is_blank(s)
-
-
-def default_if_blank(s: str, default: str = None) -> str:
-    return s if is_not_blank(s) else default
diff --git a/chat/core/src/main/python/util/text2vec.py
b/chat/core/src/main/python/util/text2vec.py index 8c56c561f..d359a4084 100644 --- a/chat/core/src/main/python/util/text2vec.py +++ b/chat/core/src/main/python/util/text2vec.py @@ -4,7 +4,7 @@ from typing import List from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from langchain.embeddings import HuggingFaceEmbeddings -from run_config import HF_TEXT2VEC_MODEL_NAME +from config.config_parse import HF_TEXT2VEC_MODEL_NAME hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)