mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Features that allow user to config LLM by config file. (#174)
This commit is contained in:
53
chat/core/src/main/python/config/config_parse.py
Normal file
53
chat/core/src/main/python/config/config_parse.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
import os
|
||||||
|
import configparser
|
||||||
|
from util.logging_utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
def type_convert(input_str: str):
|
||||||
|
try:
|
||||||
|
return eval(input_str)
|
||||||
|
except:
|
||||||
|
return input_str
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
config_file = "run_config.ini"
|
||||||
|
config_path = os.path.join(PROJECT_DIR_PATH, config_file)
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(config_path)
|
||||||
|
|
||||||
|
llm_parser_section_name = "LLMParser"
|
||||||
|
LLMPARSER_HOST = config.get(llm_parser_section_name, 'LLMPARSER_HOST')
|
||||||
|
LLMPARSER_PORT = int(config.get(llm_parser_section_name, 'LLMPARSER_PORT'))
|
||||||
|
|
||||||
|
chroma_db_section_name = "ChromaDB"
|
||||||
|
CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
|
||||||
|
PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
|
||||||
|
SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
|
||||||
|
TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME')
|
||||||
|
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM'))
|
||||||
|
TEXT2DSL_IS_SHORTCUT = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT'))
|
||||||
|
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
||||||
|
|
||||||
|
text2vec_section_name = "Text2Vec"
|
||||||
|
HF_TEXT2VEC_MODEL_NAME = config.get(text2vec_section_name, 'HF_TEXT2VEC_MODEL_NAME')
|
||||||
|
|
||||||
|
llm_provider_section_name = "LLMProvider"
|
||||||
|
LLM_PROVIDER_NAME = config.get(llm_provider_section_name, 'LLM_PROVIDER_NAME')
|
||||||
|
|
||||||
|
llm_model_section_name = "LLMModel"
|
||||||
|
llm_config_dict = {}
|
||||||
|
for option in config.options(llm_model_section_name):
|
||||||
|
llm_config_dict[option] = type_convert(config.get(llm_model_section_name, option))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH)
|
||||||
|
logger.info("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME)
|
||||||
|
logger.info("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH)
|
||||||
|
logger.info("LLMPARSER_HOST: ", LLMPARSER_HOST)
|
||||||
|
logger.info("LLMPARSER_PORT: ", LLMPARSER_PORT)
|
||||||
|
logger.info("llm_config_dict: ", llm_config_dict)
|
||||||
|
logger.info("is_shortcut: ", TEXT2DSL_IS_SHORTCUT)
|
||||||
24
chat/core/src/main/python/config/run_config.ini
Normal file
24
chat/core/src/main/python/config/run_config.ini
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
[LLMParser]
|
||||||
|
LLMPARSER_HOST = 127.0.0.1
|
||||||
|
LLMPARSER_PORT = 9092
|
||||||
|
|
||||||
|
[ChromaDB]
|
||||||
|
CHROMA_DB_PERSIST_DIR = chm_db
|
||||||
|
PRESET_QUERY_COLLECTION_NAME = preset_query_collection
|
||||||
|
SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
|
||||||
|
TEXT2DSL_COLLECTION_NAME = text2dsl_collection
|
||||||
|
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
||||||
|
TEXT2DSL_IS_SHORTCUT = False
|
||||||
|
|
||||||
|
[Text2Vec]
|
||||||
|
HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese
|
||||||
|
|
||||||
|
|
||||||
|
[LLMProvider]
|
||||||
|
LLM_PROVIDER_NAME = openai
|
||||||
|
|
||||||
|
|
||||||
|
[LLMModel]
|
||||||
|
MODEL_NAME = gpt-3.5-turbo-16k
|
||||||
|
OPENAI_API_KEY = YOUR_API_KEY
|
||||||
|
TEMPERATURE = 0.0
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
# -*- coding:utf-8 -*-
|
|
||||||
import os
|
|
||||||
|
|
||||||
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
LLMPARSER_HOST = "127.0.0.1"
|
|
||||||
LLMPARSER_PORT = 9092
|
|
||||||
|
|
||||||
MODEL_NAME = "gpt-3.5-turbo-16k"
|
|
||||||
OPENAI_API_KEY = "YOUR_API_KEY"
|
|
||||||
OPENAI_API_BASE = ""
|
|
||||||
|
|
||||||
TEMPERATURE = 0.0
|
|
||||||
|
|
||||||
CHROMA_DB_PERSIST_DIR = "chm_db"
|
|
||||||
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
|
|
||||||
SOLVED_QUERY_COLLECTION_NAME = "solved_query_collection"
|
|
||||||
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
|
|
||||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
|
||||||
TEXT2DSL_IS_SHORTCUT = False
|
|
||||||
|
|
||||||
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
|
||||||
|
|
||||||
HF_TEXT2VEC_MODEL_NAME = "GanymedeNil/text2vec-large-chinese"
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logger.info("PROJECT_DIR_PATH: {}", PROJECT_DIR_PATH)
|
|
||||||
logger.info("EMB_MODEL_PATH: {}", HF_TEXT2VEC_MODEL_NAME)
|
|
||||||
logger.info("CHROMA_DB_PERSIST_PATH: {}", CHROMA_DB_PERSIST_PATH)
|
|
||||||
logger.info("LLMPARSER_HOST: {}", LLMPARSER_HOST)
|
|
||||||
logger.info("LLMPARSER_PORT: {}", LLMPARSER_PORT)
|
|
||||||
@@ -5,7 +5,7 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
from typing import Any, List, Mapping, Union
|
from typing import Any, List, Mapping, Union
|
||||||
|
|
||||||
from loguru import logger
|
from util.logging_utils import logger
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import List
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from util.logging_utils import logger
|
||||||
from chromadb.api import Collection
|
from chromadb.api import Collection
|
||||||
|
|
||||||
from preset_query_db import (
|
from preset_query_db import (
|
||||||
@@ -18,7 +19,7 @@ from preset_query_db import (
|
|||||||
|
|
||||||
from util.text2vec import Text2VecEmbeddingFunction
|
from util.text2vec import Text2VecEmbeddingFunction
|
||||||
|
|
||||||
from run_config import PRESET_QUERY_COLLECTION_NAME
|
from config.config_parse import PRESET_QUERY_COLLECTION_NAME
|
||||||
from util.chromadb_instance import client
|
from util.chromadb_instance import client
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,18 +5,18 @@ import sys
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, List, Mapping, Optional, Union
|
from typing import Any, List, Mapping, Optional, Union
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from util.logging_utils import logger
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from chromadb.api import Collection, Documents, Embeddings
|
from chromadb.api import Collection, Documents, Embeddings
|
||||||
|
|
||||||
from util.text2vec import Text2VecEmbeddingFunction
|
from util.text2vec import Text2VecEmbeddingFunction
|
||||||
|
|
||||||
from run_config import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
|
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
|
||||||
from util.chromadb_instance import (client,
|
from util.chromadb_instance import (client,
|
||||||
get_chroma_collection_size, query_chroma_collection,
|
get_chroma_collection_size, query_chroma_collection,
|
||||||
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
|
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
|
||||||
|
|||||||
@@ -3,18 +3,18 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from typing import List, Mapping
|
from typing import List, Mapping
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from util.logging_utils import logger
|
||||||
|
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import Chroma
|
||||||
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
||||||
|
|
||||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||||
from util.text2vec import hg_embedding
|
from util.text2vec import hg_embedding
|
||||||
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
|
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
|
||||||
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||||
|
|
||||||
|
|
||||||
def reload_sql_example_collection(
|
def reload_sql_example_collection(
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ from typing import List, Mapping
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT
|
from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT,
|
||||||
|
LLMPARSER_HOST, LLMPARSER_PORT)
|
||||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||||
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
|
from util.logging_utils import logger
|
||||||
|
|
||||||
|
|
||||||
def text2dsl_setting_update(
|
def text2dsl_setting_update(
|
||||||
|
|||||||
@@ -1,6 +1,15 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from util.logging_utils import logger
|
||||||
|
|
||||||
|
|
||||||
def schema_link_parse(schema_link_output):
|
def schema_link_parse(schema_link_output):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from typing import List, Mapping
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from util.logging_utils import logger
|
||||||
|
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
||||||
|
|||||||
@@ -2,11 +2,11 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from typing import List, Union, Mapping
|
from typing import List, Union, Mapping
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from util.logging_utils import logger
|
||||||
|
|
||||||
from sql.prompt_maker import (
|
from sql.prompt_maker import (
|
||||||
schema_linking_exampler,
|
schema_linking_exampler,
|
||||||
sql_exampler,
|
sql_exampler,
|
||||||
@@ -24,7 +24,7 @@ from sql.output_parser import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from util.llm_instance import llm
|
from util.llm_instance import llm
|
||||||
from run_config import TEXT2DSL_IS_SHORTCUT
|
from config.config_parse import TEXT2DSL_IS_SHORTCUT
|
||||||
|
|
||||||
|
|
||||||
class Text2DSLAgent(object):
|
class Text2DSLAgent(object):
|
||||||
@@ -65,10 +65,10 @@ class Text2DSLAgent(object):
|
|||||||
linking: Union[List[Mapping[str, str]], None] = None,
|
linking: Union[List[Mapping[str, str]], None] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
logger.info("query_text: {}", query_text)
|
logger.info("query_text: {}".format(query_text))
|
||||||
logger.info("schema: {}", schema)
|
logger.info("schema: {}".format(schema))
|
||||||
logger.info("current_date: {}", current_date)
|
logger.info("current_date: {}".format(current_date))
|
||||||
logger.info("prior_schema_links: {}", linking)
|
logger.info("prior_schema_links: {}".format(linking))
|
||||||
|
|
||||||
if linking is not None:
|
if linking is not None:
|
||||||
prior_schema_links = {
|
prior_schema_links = {
|
||||||
@@ -87,7 +87,7 @@ class Text2DSLAgent(object):
|
|||||||
prior_schema_links,
|
prior_schema_links,
|
||||||
self.sql_example_selector,
|
self.sql_example_selector,
|
||||||
)
|
)
|
||||||
logger.info("schema_linking_prompt-> {}", schema_linking_prompt)
|
logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt))
|
||||||
schema_link_output = self.llm(schema_linking_prompt)
|
schema_link_output = self.llm(schema_linking_prompt)
|
||||||
schema_link_str = self.schema_link_parse(schema_link_output)
|
schema_link_str = self.schema_link_parse(schema_link_output)
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ class Text2DSLAgent(object):
|
|||||||
current_date,
|
current_date,
|
||||||
self.sql_example_selector,
|
self.sql_example_selector,
|
||||||
)
|
)
|
||||||
logger.info("sql_prompt->", sql_prompt)
|
logger.info("sql_prompt-> {}".format(sql_prompt))
|
||||||
sql_output = self.llm(sql_prompt)
|
sql_output = self.llm(sql_prompt)
|
||||||
|
|
||||||
resp = dict()
|
resp = dict()
|
||||||
@@ -113,7 +113,7 @@ class Text2DSLAgent(object):
|
|||||||
|
|
||||||
resp["sqlOutput"] = sql_output
|
resp["sqlOutput"] = sql_output
|
||||||
|
|
||||||
logger.info("resp: ", resp)
|
logger.info("resp: {}".format(resp))
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@@ -125,10 +125,10 @@ class Text2DSLAgent(object):
|
|||||||
linking: Union[List[Mapping[str, str]], None] = None,
|
linking: Union[List[Mapping[str, str]], None] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
logger.info("query_text: ", query_text)
|
logger.info("query_text: {}".format(query_text))
|
||||||
logger.info("schema: ", schema)
|
logger.info("schema: {}".format(schema))
|
||||||
logger.info("current_date: ", current_date)
|
logger.info("current_date: {}".format(current_date))
|
||||||
logger.info("prior_schema_links: ", linking)
|
logger.info("prior_schema_links: {}".format(linking))
|
||||||
|
|
||||||
if linking is not None:
|
if linking is not None:
|
||||||
prior_schema_links = {
|
prior_schema_links = {
|
||||||
@@ -148,7 +148,7 @@ class Text2DSLAgent(object):
|
|||||||
prior_schema_links,
|
prior_schema_links,
|
||||||
self.sql_example_selector,
|
self.sql_example_selector,
|
||||||
)
|
)
|
||||||
logger.info("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
|
logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt))
|
||||||
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
|
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
|
||||||
|
|
||||||
schema_linking_str = self.combo_schema_link_parse(
|
schema_linking_str = self.combo_schema_link_parse(
|
||||||
@@ -167,7 +167,7 @@ class Text2DSLAgent(object):
|
|||||||
resp["schemaLinkStr"] = schema_linking_str
|
resp["schemaLinkStr"] = schema_linking_str
|
||||||
resp["sqlOutput"] = sql_str
|
resp["sqlOutput"] = sql_str
|
||||||
|
|
||||||
logger.info("resp: ", resp)
|
logger.info("resp: {}".format(resp))
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,6 @@ import sys
|
|||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from util.logging_utils import init_logger
|
|
||||||
|
|
||||||
init_logger()
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
@@ -15,7 +11,7 @@ from typing import Any, List, Mapping
|
|||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
|
|
||||||
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
|
from config.config_parse import LLMPARSER_HOST, LLMPARSER_PORT
|
||||||
|
|
||||||
from services_router import (query2sql_service, preset_query_service,
|
from services_router import (query2sql_service, preset_query_service,
|
||||||
solved_query_service, plugin_call_service)
|
solved_query_service, plugin_call_service)
|
||||||
|
|||||||
@@ -4,9 +4,15 @@ from typing import Any, List, Mapping, Optional, Union
|
|||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.api import Collection
|
from chromadb.api import Collection
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from run_config import CHROMA_DB_PERSIST_PATH
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from util.logging_utils import logger
|
||||||
|
|
||||||
|
from config.config_parse import CHROMA_DB_PERSIST_PATH
|
||||||
|
|
||||||
client = chromadb.Client(
|
client = chromadb.Client(
|
||||||
Settings(
|
Settings(
|
||||||
@@ -75,7 +81,7 @@ def query_chroma_collection(collection:Collection, query_texts:List[str],
|
|||||||
else:
|
else:
|
||||||
outer_filter = None
|
outer_filter = None
|
||||||
|
|
||||||
logger.info('outer_filter: ', outer_filter)
|
print('outer_filter: ', outer_filter)
|
||||||
res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter)
|
res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,21 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
from langchain.llms import OpenAI
|
from langchain import llms
|
||||||
|
|
||||||
from run_config import *
|
import os
|
||||||
from util.stringutils import *
|
import sys
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
llm = OpenAI(
|
from config.config_parse import LLM_PROVIDER_NAME, llm_config_dict
|
||||||
model_name=MODEL_NAME,
|
|
||||||
openai_api_key=OPENAI_API_KEY,
|
|
||||||
openai_api_base=default_if_blank(OPENAI_API_BASE),
|
def get_llm_provider(llm_provider_name: str, llm_config_dict: dict):
|
||||||
temperature=TEMPERATURE,
|
if llm_provider_name in llms.type_to_cls_dict:
|
||||||
)
|
llm_provider = llms.type_to_cls_dict[llm_provider_name]
|
||||||
|
llm = llm_provider(**llm_config_dict)
|
||||||
|
return llm
|
||||||
|
else:
|
||||||
|
raise Exception("llm_provider_name is not supported: {}".format(llm_provider_name))
|
||||||
|
|
||||||
|
|
||||||
|
llm = get_llm_provider(LLM_PROVIDER_NAME, llm_config_dict)
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
|
||||||
def init_logger():
|
|
||||||
logger.remove()
|
|
||||||
logger.add("llmparser.info.log")
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
def is_blank(s: str) -> bool:
|
|
||||||
return not (s and s.strip())
|
|
||||||
|
|
||||||
|
|
||||||
def is_not_blank(s: str) -> bool:
|
|
||||||
return not is_blank(s)
|
|
||||||
|
|
||||||
|
|
||||||
def default_if_blank(s: str, default: str = None) -> str:
|
|
||||||
return s if is_not_blank(s) else default
|
|
||||||
@@ -4,7 +4,7 @@ from typing import List
|
|||||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
from run_config import HF_TEXT2VEC_MODEL_NAME
|
from config.config_parse import HF_TEXT2VEC_MODEL_NAME
|
||||||
|
|
||||||
hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)
|
hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user