mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
Merge branch 'master' into feature/metricAddTag
This commit is contained in:
@@ -16,4 +16,10 @@ public class LLMParserConfig {
|
|||||||
@Value("${query2sql.path:/query2sql}")
|
@Value("${query2sql.path:/query2sql}")
|
||||||
private String queryToSqlPath;
|
private String queryToSqlPath;
|
||||||
|
|
||||||
|
@Value("${dimension.topn:5}")
|
||||||
|
private Integer dimensionTopN;
|
||||||
|
|
||||||
|
@Value("${metric.topn:5}")
|
||||||
|
private Integer metricTopN;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
|||||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -87,7 +88,7 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
LLMReq llmReq = getLlmReq(queryCtx, modelId);
|
LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig);
|
||||||
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
|
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
|
||||||
|
|
||||||
if (Objects.isNull(llmResp)) {
|
if (Objects.isNull(llmResp)) {
|
||||||
@@ -340,22 +341,28 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
|
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||||
String queryText = queryCtx.getRequest().getQueryText();
|
String queryText = queryCtx.getRequest().getQueryText();
|
||||||
|
|
||||||
LLMReq llmReq = new LLMReq();
|
LLMReq llmReq = new LLMReq();
|
||||||
llmReq.setQueryText(queryText);
|
llmReq.setQueryText(queryText);
|
||||||
|
|
||||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||||
llmSchema.setModelName(modelIdToName.get(modelId));
|
llmSchema.setModelName(modelIdToName.get(modelId));
|
||||||
llmSchema.setDomainName(modelIdToName.get(modelId));
|
llmSchema.setDomainName(modelIdToName.get(modelId));
|
||||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
|
|
||||||
|
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
|
||||||
|
|
||||||
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
|
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
|
||||||
llmSchema.setFieldNameList(fieldNameList);
|
llmSchema.setFieldNameList(fieldNameList);
|
||||||
llmReq.setSchema(llmSchema);
|
llmReq.setSchema(llmSchema);
|
||||||
|
|
||||||
List<ElementValue> linking = new ArrayList<>();
|
List<ElementValue> linking = new ArrayList<>();
|
||||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
||||||
llmReq.setLinking(linking);
|
llmReq.setLinking(linking);
|
||||||
|
|
||||||
String currentDate = DSLDateHelper.getReferenceDate(modelId);
|
String currentDate = DSLDateHelper.getReferenceDate(modelId);
|
||||||
llmReq.setCurrentDate(currentDate);
|
llmReq.setCurrentDate(currentDate);
|
||||||
return llmReq;
|
return llmReq;
|
||||||
@@ -399,12 +406,29 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
||||||
|
LLMParserConfig llmParserConfig) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||||
|
|
||||||
|
Set<String> results = semanticSchema.getDimensions().stream()
|
||||||
|
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
||||||
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
|
.limit(llmParserConfig.getDimensionTopN())
|
||||||
|
.map(entry -> entry.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
Set<String> metrics = semanticSchema.getMetrics().stream()
|
||||||
|
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
||||||
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
|
.limit(llmParserConfig.getMetricTopN())
|
||||||
|
.map(entry -> entry.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
results.addAll(metrics);
|
||||||
|
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
Set<String> fieldNameList = matchedElements.stream()
|
Set<String> fieldNameList = matchedElements.stream()
|
||||||
.filter(schemaElementMatch -> {
|
.filter(schemaElementMatch -> {
|
||||||
@@ -423,7 +447,8 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
})
|
})
|
||||||
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
|
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
return new ArrayList<>(fieldNameList);
|
results.addAll(fieldNameList);
|
||||||
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ CHROMA_DB_PERSIST_DIR = 'chm_db'
|
|||||||
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
|
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
|
||||||
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
|
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
|
||||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
||||||
|
TEXT2DSL_IS_SHORTCUT = False
|
||||||
|
|
||||||
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
||||||
|
|
||||||
|
|||||||
@@ -22,10 +22,8 @@ from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
|
|||||||
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
|
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
|
||||||
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||||
|
|
||||||
|
|
||||||
def reload_sql_example_collection(vectorstore:Chroma,
|
def reload_sql_example_collection(vectorstore:Chroma,
|
||||||
sql_examplars:List[Mapping[str, str]],
|
sql_examplars:List[Mapping[str, str]],
|
||||||
schema_linking_example_selector:SemanticSimilarityExampleSelector,
|
|
||||||
sql_example_selector:SemanticSimilarityExampleSelector,
|
sql_example_selector:SemanticSimilarityExampleSelector,
|
||||||
example_nums:int
|
example_nums:int
|
||||||
):
|
):
|
||||||
@@ -35,20 +33,16 @@ def reload_sql_example_collection(vectorstore:Chroma,
|
|||||||
|
|
||||||
print("emptied sql_examples_collection size:", vectorstore._collection.count())
|
print("emptied sql_examples_collection size:", vectorstore._collection.count())
|
||||||
|
|
||||||
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
|
||||||
input_keys=["question"],
|
|
||||||
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
|
|
||||||
|
|
||||||
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
||||||
input_keys=["question"],
|
input_keys=["question"],
|
||||||
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
|
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
|
||||||
|
|
||||||
for example in sql_examplars:
|
for example in sql_examplars:
|
||||||
schema_linking_example_selector.add_example(example)
|
sql_example_selector.add_example(example)
|
||||||
|
|
||||||
print("reloaded sql_examples_collection size:", vectorstore._collection.count())
|
print("reloaded sql_examples_collection size:", vectorstore._collection.count())
|
||||||
|
|
||||||
return vectorstore, schema_linking_example_selector, sql_example_selector
|
return vectorstore, sql_example_selector
|
||||||
|
|
||||||
|
|
||||||
sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
|
sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
|
||||||
@@ -57,22 +51,14 @@ sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
|
|||||||
|
|
||||||
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||||
|
|
||||||
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
|
||||||
input_keys=["question"],
|
|
||||||
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
|
|
||||||
|
|
||||||
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
||||||
input_keys=["question"],
|
input_keys=["question"],
|
||||||
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
|
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
|
||||||
|
|
||||||
if sql_examples_vectorstore._collection.count() > 0:
|
if sql_examples_vectorstore._collection.count() > 0:
|
||||||
print("examples already in sql_vectorstore")
|
print("examples already in sql_vectorstore")
|
||||||
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
||||||
if sql_examples_vectorstore._collection.count() < len(sql_examplars):
|
|
||||||
print("sql_examplars size:", len(sql_examplars))
|
|
||||||
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
|
|
||||||
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
|
||||||
else:
|
|
||||||
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
|
|
||||||
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
|
||||||
|
|
||||||
|
print("sql_examplars size:", len(sql_examplars))
|
||||||
|
sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums)
|
||||||
|
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
||||||
|
|||||||
@@ -8,24 +8,22 @@ import json
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT
|
||||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||||
from run_config import LLMPARSER_HOST
|
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
|
||||||
from run_config import LLMPARSER_PORT
|
|
||||||
|
|
||||||
|
|
||||||
def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str,
|
def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str,
|
||||||
sql_examplars:List[Mapping[str, str]], example_nums:int):
|
sql_examplars:List[Mapping[str, str]], example_nums:int, is_shortcut:bool):
|
||||||
|
|
||||||
url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/"
|
url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/"
|
||||||
print("url: ", url)
|
print("url: ", url)
|
||||||
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums}
|
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums, "isShortcut":is_shortcut}
|
||||||
headers = {'content-type': 'application/json'}
|
headers = {'content-type': 'application/json'}
|
||||||
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||||
print(response.text)
|
print(response.text)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
arguments = sys.argv
|
|
||||||
text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT,
|
text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT,
|
||||||
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM)
|
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT)
|
||||||
|
|||||||
@@ -10,4 +10,36 @@ def schema_link_parse(schema_link_output):
|
|||||||
print(e)
|
print(e)
|
||||||
schema_link_output = None
|
schema_link_output = None
|
||||||
|
|
||||||
return schema_link_output
|
return schema_link_output
|
||||||
|
|
||||||
|
def combo_schema_link_parse(schema_linking_sql_combo_output: str):
|
||||||
|
try:
|
||||||
|
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
|
||||||
|
pattern = r'Schema_links:(\[.*?\])'
|
||||||
|
schema_links_match = re.search(pattern, schema_linking_sql_combo_output)
|
||||||
|
|
||||||
|
if schema_links_match:
|
||||||
|
schema_links = schema_links_match.group(1)
|
||||||
|
else:
|
||||||
|
schema_links = None
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
schema_links = None
|
||||||
|
|
||||||
|
return schema_links
|
||||||
|
|
||||||
|
def combo_sql_parse(schema_linking_sql_combo_output: str):
|
||||||
|
try:
|
||||||
|
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
|
||||||
|
pattern = r'SQL:(.*)'
|
||||||
|
sql_match = re.search(pattern, schema_linking_sql_combo_output)
|
||||||
|
|
||||||
|
if sql_match:
|
||||||
|
sql = sql_match.group(1)
|
||||||
|
else:
|
||||||
|
sql = None
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
sql = None
|
||||||
|
|
||||||
|
return sql
|
||||||
|
|||||||
@@ -73,3 +73,38 @@ def sql_exampler(user_query: str,
|
|||||||
schema_links=schema_link_str)
|
schema_links=schema_link_str)
|
||||||
|
|
||||||
return sql_example_prompt
|
return sql_example_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def schema_linking_sql_combo_examplar(user_query: str,
|
||||||
|
domain_name: str,
|
||||||
|
data_date : str,
|
||||||
|
fields_list: List[str],
|
||||||
|
prior_schema_links: Mapping[str,str],
|
||||||
|
example_selector: SemanticSimilarityExampleSelector) -> str:
|
||||||
|
|
||||||
|
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
|
||||||
|
|
||||||
|
example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"],
|
||||||
|
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}")
|
||||||
|
|
||||||
|
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
|
||||||
|
|
||||||
|
schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
|
||||||
|
|
||||||
|
schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate(
|
||||||
|
example_selector=example_selector,
|
||||||
|
example_prompt=example_prompt_template,
|
||||||
|
example_separator="\n\n",
|
||||||
|
prefix=instruction,
|
||||||
|
input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question"],
|
||||||
|
suffix=schema_linking_sql_combo_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
schema_linking_sql_combo_example_prompt = schema_linking_sql_combo_example_prompt_template.format(table_name=domain_name,
|
||||||
|
fields_list=fields_list,
|
||||||
|
prior_schema_links=prior_schema_links_str,
|
||||||
|
current_date=data_date,
|
||||||
|
question=user_query)
|
||||||
|
return schema_linking_sql_combo_example_prompt
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,32 +7,37 @@ import sys
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from sql.prompt_maker import schema_linking_exampler, sql_exampler
|
from sql.prompt_maker import schema_linking_exampler, sql_exampler, schema_linking_sql_combo_examplar
|
||||||
from sql.constructor import schema_linking_example_selector, sql_example_selector,sql_examples_vectorstore, reload_sql_example_collection
|
from sql.constructor import sql_examples_vectorstore, sql_example_selector, reload_sql_example_collection
|
||||||
from sql.output_parser import schema_link_parse
|
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||||
|
|
||||||
from util.llm_instance import llm
|
from util.llm_instance import llm
|
||||||
|
from run_config import TEXT2DSL_IS_SHORTCUT
|
||||||
|
|
||||||
class Text2DSLAgent(object):
|
class Text2DSLAgent(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.schema_linking_exampler = schema_linking_exampler
|
self.schema_linking_exampler = schema_linking_exampler
|
||||||
self.sql_exampler = sql_exampler
|
self.sql_exampler = sql_exampler
|
||||||
|
|
||||||
|
self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar
|
||||||
|
|
||||||
self.sql_examples_vectorstore = sql_examples_vectorstore
|
self.sql_examples_vectorstore = sql_examples_vectorstore
|
||||||
self.schema_linking_example_selector = schema_linking_example_selector
|
|
||||||
self.sql_example_selector = sql_example_selector
|
self.sql_example_selector = sql_example_selector
|
||||||
|
|
||||||
self.schema_link_parse = schema_link_parse
|
self.schema_link_parse = schema_link_parse
|
||||||
|
self.combo_schema_link_parse = combo_schema_link_parse
|
||||||
|
self.combo_sql_parse = combo_sql_parse
|
||||||
|
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
|
|
||||||
def update_examples(self, sql_examplars, example_nums):
|
self.is_shortcut = TEXT2DSL_IS_SHORTCUT
|
||||||
self.sql_examples_vectorstore, self.schema_linking_example_selector, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
|
|
||||||
sql_examplars,
|
def update_examples(self, sql_examples, example_nums, is_shortcut):
|
||||||
self.schema_linking_example_selector,
|
self.sql_examples_vectorstore, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
|
||||||
self.sql_example_selector,
|
sql_examples,
|
||||||
example_nums)
|
self.sql_example_selector,
|
||||||
|
example_nums)
|
||||||
|
self.is_shortcut = is_shortcut
|
||||||
|
|
||||||
def query2sql(self, query_text: str,
|
def query2sql(self, query_text: str,
|
||||||
schema : Union[dict, None] = None,
|
schema : Union[dict, None] = None,
|
||||||
@@ -53,14 +58,14 @@ class Text2DSLAgent(object):
|
|||||||
model_name = schema['modelName']
|
model_name = schema['modelName']
|
||||||
fields_list = schema['fieldNameList']
|
fields_list = schema['fieldNameList']
|
||||||
|
|
||||||
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.schema_linking_example_selector)
|
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.sql_example_selector)
|
||||||
print("schema_linking_prompt->", schema_linking_prompt)
|
print("schema_linking_prompt->", schema_linking_prompt)
|
||||||
schema_link_output = self.llm(schema_linking_prompt)
|
schema_link_output = self.llm(schema_linking_prompt)
|
||||||
schema_link_str = self.schema_link_parse(schema_link_output)
|
schema_link_str = self.schema_link_parse(schema_link_output)
|
||||||
|
|
||||||
sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector)
|
sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector)
|
||||||
print("sql_prompt->", sql_prompt)
|
print("sql_prompt->", sql_prompt)
|
||||||
sql_output = llm(sql_prompt)
|
sql_output = self.llm(sql_prompt)
|
||||||
|
|
||||||
resp = dict()
|
resp = dict()
|
||||||
resp['query'] = query_text
|
resp['query'] = query_text
|
||||||
@@ -69,7 +74,7 @@ class Text2DSLAgent(object):
|
|||||||
resp['priorSchemaLinking'] = linking
|
resp['priorSchemaLinking'] = linking
|
||||||
resp['dataDate'] = current_date
|
resp['dataDate'] = current_date
|
||||||
|
|
||||||
resp['schemaLinkingOutput'] = schema_link_output
|
resp['analysisOutput'] = schema_link_output
|
||||||
resp['schemaLinkStr'] = schema_link_str
|
resp['schemaLinkStr'] = schema_link_str
|
||||||
|
|
||||||
resp['sqlOutput'] = sql_output
|
resp['sqlOutput'] = sql_output
|
||||||
@@ -78,5 +83,57 @@ class Text2DSLAgent(object):
|
|||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
def query2sqlcombo(self, query_text: str,
|
||||||
|
schema : Union[dict, None] = None,
|
||||||
|
current_date: str = None,
|
||||||
|
linking: Union[List[Mapping[str, str]], None] = None
|
||||||
|
):
|
||||||
|
|
||||||
|
print("query_text: ", query_text)
|
||||||
|
print("schema: ", schema)
|
||||||
|
print("current_date: ", current_date)
|
||||||
|
print("prior_schema_links: ", linking)
|
||||||
|
|
||||||
|
if linking is not None:
|
||||||
|
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
|
||||||
|
else:
|
||||||
|
prior_schema_links = {}
|
||||||
|
|
||||||
|
model_name = schema['modelName']
|
||||||
|
fields_list = schema['fieldNameList']
|
||||||
|
|
||||||
|
schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(query_text, model_name, current_date, fields_list,
|
||||||
|
prior_schema_links, self.sql_example_selector)
|
||||||
|
print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
|
||||||
|
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
|
||||||
|
|
||||||
|
schema_linking_str = self.combo_schema_link_parse(schema_linking_sql_combo_output)
|
||||||
|
sql_str = self.combo_sql_parse(schema_linking_sql_combo_output)
|
||||||
|
|
||||||
|
resp = dict()
|
||||||
|
resp['query'] = query_text
|
||||||
|
resp['model'] = model_name
|
||||||
|
resp['fields'] = fields_list
|
||||||
|
resp['priorSchemaLinking'] = prior_schema_links
|
||||||
|
resp['dataDate'] = current_date
|
||||||
|
|
||||||
|
resp['analysisOutput'] = schema_linking_sql_combo_output
|
||||||
|
resp['schemaLinkStr'] = schema_linking_str
|
||||||
|
resp['sqlOutput'] = sql_str
|
||||||
|
|
||||||
|
print("resp: ", resp)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def query2sql_run(self, query_text: str,
|
||||||
|
schema : Union[dict, None] = None,
|
||||||
|
current_date: str = None,
|
||||||
|
linking: Union[List[Mapping[str, str]], None] = None):
|
||||||
|
|
||||||
|
if self.is_shortcut:
|
||||||
|
return self.query2sqlcombo(query_text, schema, current_date, linking)
|
||||||
|
else:
|
||||||
|
return self.query2sql(query_text, schema, current_date, linking)
|
||||||
|
|
||||||
text2sql_agent = Text2DSLAgent()
|
text2sql_agent = Text2DSLAgent()
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ async def din_query2sql(query_body: Mapping[str, Any]):
|
|||||||
else:
|
else:
|
||||||
linking = query_body['linking']
|
linking = query_body['linking']
|
||||||
|
|
||||||
resp = text2sql_agent.query2sql(query_text=query_text,
|
resp = text2sql_agent.query2sql_run(query_text=query_text,
|
||||||
schema=schema, current_date=current_date, linking=linking)
|
schema=schema, current_date=current_date, linking=linking)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
@@ -70,7 +70,12 @@ async def query2sql_setting_update(query_body: Mapping[str, Any]):
|
|||||||
else:
|
else:
|
||||||
example_nums = query_body['exampleNums']
|
example_nums = query_body['exampleNums']
|
||||||
|
|
||||||
text2sql_agent.update_examples(sql_examplars=sql_examplars, example_nums=example_nums)
|
if 'isShortcut' not in query_body:
|
||||||
|
raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
|
||||||
|
else:
|
||||||
|
is_shortcut = query_body['isShortcut']
|
||||||
|
|
||||||
|
text2sql_agent.update_examples(sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut)
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
|||||||
@@ -40,4 +40,5 @@ public class DefaultSemanticConfig {
|
|||||||
|
|
||||||
@Value("${explain.path:/api/semantic/query/explain}")
|
@Value("${explain.path:/api/semantic/query/explain}")
|
||||||
private String explainPath;
|
private String explainPath;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,8 +81,9 @@ public class LocalSemanticLayer extends BaseSemanticLayer {
|
|||||||
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
|
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
|
||||||
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
|
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
|
||||||
filter.setModelIds(ids);
|
filter.setModelIds(ids);
|
||||||
modelService = ContextUtils.getBean(ModelService.class);
|
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||||
return modelService.fetchModelSchema(filter);
|
User user = User.getFakeUser();
|
||||||
|
return schemaService.fetchModelSchema(filter, user);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ public class SchemaService {
|
|||||||
|
|
||||||
|
|
||||||
public static final String ALL_CACHE = "all";
|
public static final String ALL_CACHE = "all";
|
||||||
private static final Integer META_CACHE_TIME = 5;
|
private static final Integer META_CACHE_TIME = 2;
|
||||||
private SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
private SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||||
|
|
||||||
private LoadingCache<String, SemanticSchema> cache = CacheBuilder.newBuilder()
|
private LoadingCache<String, SemanticSchema> cache = CacheBuilder.newBuilder()
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 107 KiB |
@@ -5,21 +5,25 @@ text2sql的功能实现,高度依赖对LLM的应用。通过LLM生成SQL的过
|
|||||||
|
|
||||||
### **配置方式**
|
### **配置方式**
|
||||||
1. 样本池的配置。
|
1. 样本池的配置。
|
||||||
- supersonic/chat/core/src/main/python/llm/few_shot_example/sql_exampler.py为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。
|
- supersonic/chat/core/src/main/python/few_shot_example/sql_exampler.py 为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。
|
||||||
2. 样本数量的配置。
|
2. 样本数量的配置。
|
||||||
- 在supersonic/chat/core/src/main/python/llm/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。
|
- 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。
|
||||||
- 默认值为15,为项目在内部实践后较优的经验值。样本少太少,对导致LLM在生成SQL的过程中缺少引导和示范,生成的SQL会更不稳定;样本太多,会增加生成SQL需要的时间和LLM的token消耗(或超过LLM的token上限)。
|
- 默认值为15,为项目在内部实践后较优的经验值。样本少太少,对导致LLM在生成SQL的过程中缺少引导和示范,生成的SQL会更不稳定;样本太多,会增加生成SQL需要的时间和LLM的token消耗(或超过LLM的token上限)。
|
||||||
- <div align="left" >
|
3. SQL生成方式的配置
|
||||||
|
- 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_IS_SHORTCUT 变量进行配置。
|
||||||
|
- 默认值为False;当为False时,会调用2次LLM生成SQL;当为True时,会只调用1次LLM生成SQL。相较于2次LLM调用生成的SQL,耗时会减少30-40%,token的消耗量会减少30%左右,但生成的SQL正确率会有所下降。
|
||||||
|
<div align="left" >
|
||||||
<img src=../images/text2sql_config.png width="70%"/>
|
<img src=../images/text2sql_config.png width="70%"/>
|
||||||
<p>图1-1 样本数量的配置文件</p>
|
<p>图1-1 配置文件</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
3. 运行中更新配置的脚本。
|
### **运行中更新配置的脚本**
|
||||||
- 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过脚本 supersonic/chat/core/src/main/python/bin/text2sql_resetting.sh 在项目运行中让配置生效。
|
1. 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过以下2种方式让配置在项目运行中让配置生效。
|
||||||
|
- 执行 supersonic-daemon.sh reload llmparser
|
||||||
|
- 执行 python examples_reload_run.py
|
||||||
### **FAQ**
|
### **FAQ**
|
||||||
1. 生成一个SQL需要消耗的的LLM token数量太多了,按照openAI对token的收费标准,生成一个SQL太贵了,可以少用一些token吗?
|
1. 生成一个SQL需要消耗的的LLM token数量太多了,按照openAI对token的收费标准,生成一个SQL太贵了,可以少用一些token吗?
|
||||||
- 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。
|
- 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。配置方式3.中所示,只调用1次LLM生成SQL。
|
||||||
- 需要注意,样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。
|
- 需要注意,样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ public class QueryStat {
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public QueryStat setClassId(Long modelId) {
|
public QueryStat setModelId(Long modelId) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.semantic.api.query.request;
|
package com.tencent.supersonic.semantic.api.query.request;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
@@ -11,6 +12,7 @@ public class ItemUseReq {
|
|||||||
|
|
||||||
private String startTime;
|
private String startTime;
|
||||||
private Long modelId;
|
private Long modelId;
|
||||||
|
private List<Long> modelIds;
|
||||||
private Boolean cacheEnable = true;
|
private Boolean cacheEnable = true;
|
||||||
private String metric;
|
private String metric;
|
||||||
|
|
||||||
@@ -18,4 +20,8 @@ public class ItemUseReq {
|
|||||||
this.startTime = startTime;
|
this.startTime = startTime;
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
}
|
}
|
||||||
|
public ItemUseReq(String startTime, List<Long> modelIds) {
|
||||||
|
this.startTime = startTime;
|
||||||
|
this.modelIds = modelIds;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,44 +4,43 @@ import com.alibaba.fastjson.JSONObject;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.request.ModelReq;
|
import com.tencent.supersonic.semantic.api.model.request.ModelReq;
|
||||||
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
|
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DatabaseResp;
|
import com.tencent.supersonic.semantic.api.model.response.DatabaseResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DatasourceResp;
|
import com.tencent.supersonic.semantic.api.model.response.DatasourceResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.model.domain.DatabaseService;
|
import com.tencent.supersonic.semantic.model.domain.DatabaseService;
|
||||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
|
||||||
import com.tencent.supersonic.semantic.model.domain.DomainService;
|
|
||||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
|
||||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
|
||||||
import com.tencent.supersonic.semantic.model.domain.DatasourceService;
|
import com.tencent.supersonic.semantic.model.domain.DatasourceService;
|
||||||
|
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||||
|
import com.tencent.supersonic.semantic.model.domain.DomainService;
|
||||||
|
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||||
|
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||||
import com.tencent.supersonic.semantic.model.domain.dataobject.ModelDO;
|
import com.tencent.supersonic.semantic.model.domain.dataobject.ModelDO;
|
||||||
import com.tencent.supersonic.semantic.model.domain.pojo.Model;
|
import com.tencent.supersonic.semantic.model.domain.pojo.Model;
|
||||||
import com.tencent.supersonic.semantic.model.domain.repository.ModelRepository;
|
import com.tencent.supersonic.semantic.model.domain.repository.ModelRepository;
|
||||||
import com.tencent.supersonic.semantic.model.domain.utils.ModelConvert;
|
import com.tencent.supersonic.semantic.model.domain.utils.ModelConvert;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Date;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.context.annotation.Lazy;
|
import org.springframework.context.annotation.Lazy;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Date;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
@Service
|
||||||
|
|||||||
@@ -67,8 +67,11 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Object queryBySql(QueryDslReq querySqlCmd, User user) throws Exception {
|
public Object queryBySql(QueryDslReq querySqlCmd, User user) throws Exception {
|
||||||
|
statUtils.initStatInfo(querySqlCmd, user);
|
||||||
QueryStatement queryStatement = convertToQueryStatement(querySqlCmd, user);
|
QueryStatement queryStatement = convertToQueryStatement(querySqlCmd, user);
|
||||||
return semanticQueryEngine.execute(queryStatement);
|
QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement);
|
||||||
|
statUtils.statInfo2DbAsync(TaskStatusEnum.SUCCESS);
|
||||||
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception {
|
private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception {
|
||||||
|
|||||||
@@ -55,7 +55,10 @@ public class SchemaServiceImpl implements SchemaService {
|
|||||||
@Override
|
@Override
|
||||||
public List<ModelSchemaResp> fetchModelSchema(ModelSchemaFilterReq filter, User user) {
|
public List<ModelSchemaResp> fetchModelSchema(ModelSchemaFilterReq filter, User user) {
|
||||||
List<ModelSchemaResp> domainSchemaDescList = modelService.fetchModelSchema(filter);
|
List<ModelSchemaResp> domainSchemaDescList = modelService.fetchModelSchema(filter);
|
||||||
List<ItemUseResp> statInfos = queryService.getStatInfo(new ItemUseReq());
|
ItemUseReq itemUseCommend = new ItemUseReq();
|
||||||
|
itemUseCommend.setModelIds(filter.getModelIds());
|
||||||
|
|
||||||
|
List<ItemUseResp> statInfos = queryService.getStatInfo(itemUseCommend);
|
||||||
log.debug("statInfos:{}", statInfos);
|
log.debug("statInfos:{}", statInfos);
|
||||||
fillCnt(domainSchemaDescList, statInfos);
|
fillCnt(domainSchemaDescList, statInfos);
|
||||||
return domainSchemaDescList;
|
return domainSchemaDescList;
|
||||||
|
|||||||
@@ -4,22 +4,30 @@ import com.alibaba.ttl.TransmittableThreadLocal;
|
|||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeBackEnum;
|
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeBackEnum;
|
||||||
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
|
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
|
||||||
import com.tencent.supersonic.semantic.api.model.pojo.QueryStat;
|
import com.tencent.supersonic.semantic.api.model.pojo.QueryStat;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.ItemUseReq;
|
import com.tencent.supersonic.semantic.api.query.request.ItemUseReq;
|
||||||
|
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.response.ItemUseResp;
|
import com.tencent.supersonic.semantic.api.query.response.ItemUseResp;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||||
import com.tencent.supersonic.semantic.query.persistence.repository.StatRepository;
|
import com.tencent.supersonic.semantic.query.persistence.repository.StatRepository;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.codec.digest.DigestUtils;
|
import org.apache.commons.codec.digest.DigestUtils;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -28,13 +36,17 @@ public class StatUtils {
|
|||||||
private static final TransmittableThreadLocal<QueryStat> STATS = new TransmittableThreadLocal<>();
|
private static final TransmittableThreadLocal<QueryStat> STATS = new TransmittableThreadLocal<>();
|
||||||
private final StatRepository statRepository;
|
private final StatRepository statRepository;
|
||||||
private final SqlFilterUtils sqlFilterUtils;
|
private final SqlFilterUtils sqlFilterUtils;
|
||||||
|
|
||||||
|
private final ModelService modelService;
|
||||||
private final ObjectMapper objectMapper = new ObjectMapper();
|
private final ObjectMapper objectMapper = new ObjectMapper();
|
||||||
|
|
||||||
public StatUtils(StatRepository statRepository,
|
public StatUtils(StatRepository statRepository,
|
||||||
SqlFilterUtils sqlFilterUtils) {
|
SqlFilterUtils sqlFilterUtils,
|
||||||
|
ModelService modelService) {
|
||||||
|
|
||||||
this.statRepository = statRepository;
|
this.statRepository = statRepository;
|
||||||
this.sqlFilterUtils = sqlFilterUtils;
|
this.sqlFilterUtils = sqlFilterUtils;
|
||||||
|
this.modelService = modelService;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static QueryStat get() {
|
public static QueryStat get() {
|
||||||
@@ -69,6 +81,44 @@ public class StatUtils {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public void initStatInfo(QueryDslReq queryDslReq, User facadeUser) {
|
||||||
|
QueryStat queryStatInfo = new QueryStat();
|
||||||
|
List<String> allFields = SqlParserSelectHelper.getAllFields(queryDslReq.getSql());
|
||||||
|
queryStatInfo.setModelId(queryDslReq.getModelId());
|
||||||
|
ModelSchemaResp modelSchemaResp = modelService.fetchSingleModelSchema(queryDslReq.getModelId());
|
||||||
|
|
||||||
|
List<String> dimensions = new ArrayList<>();
|
||||||
|
if (Objects.nonNull(modelSchemaResp)) {
|
||||||
|
dimensions = getFieldNames(allFields, modelSchemaResp.getDimensions());
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> metrics = new ArrayList<>();
|
||||||
|
if (Objects.nonNull(modelSchemaResp)) {
|
||||||
|
metrics = getFieldNames(allFields, modelSchemaResp.getMetrics());
|
||||||
|
}
|
||||||
|
|
||||||
|
String userName = getUserName(facadeUser);
|
||||||
|
try {
|
||||||
|
queryStatInfo.setTraceId("")
|
||||||
|
.setModelId(queryDslReq.getModelId())
|
||||||
|
.setUser(userName)
|
||||||
|
.setQueryType(QueryTypeEnum.SQL.getValue())
|
||||||
|
.setQueryTypeBack(QueryTypeBackEnum.NORMAL.getState())
|
||||||
|
.setQuerySqlCmd(queryDslReq.toString())
|
||||||
|
.setQuerySqlCmdMd5(DigestUtils.md5Hex(queryDslReq.toString()))
|
||||||
|
.setStartTime(System.currentTimeMillis())
|
||||||
|
.setUseResultCache(true)
|
||||||
|
.setUseSqlCache(true)
|
||||||
|
.setMetrics(objectMapper.writeValueAsString(metrics))
|
||||||
|
.setDimensions(objectMapper.writeValueAsString(dimensions));
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
log.error("initStatInfo:{}", e);
|
||||||
|
}
|
||||||
|
StatUtils.set(queryStatInfo);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
public void initStatInfo(QueryStructReq queryStructCmd, User facadeUser) {
|
public void initStatInfo(QueryStructReq queryStructCmd, User facadeUser) {
|
||||||
QueryStat queryStatInfo = new QueryStat();
|
QueryStat queryStatInfo = new QueryStat();
|
||||||
String traceId = "";
|
String traceId = "";
|
||||||
@@ -76,12 +126,11 @@ public class StatUtils {
|
|||||||
|
|
||||||
List<String> metrics = new ArrayList<>();
|
List<String> metrics = new ArrayList<>();
|
||||||
queryStructCmd.getAggregators().stream().forEach(aggregator -> metrics.add(aggregator.getColumn()));
|
queryStructCmd.getAggregators().stream().forEach(aggregator -> metrics.add(aggregator.getColumn()));
|
||||||
String user = (Objects.nonNull(facadeUser) && Strings.isNotEmpty(facadeUser.getName())) ? facadeUser.getName()
|
String user = getUserName(facadeUser);
|
||||||
: "Admin";
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
queryStatInfo.setTraceId(traceId)
|
queryStatInfo.setTraceId(traceId)
|
||||||
.setClassId(queryStructCmd.getModelId())
|
.setModelId(queryStructCmd.getModelId())
|
||||||
.setUser(user)
|
.setUser(user)
|
||||||
.setQueryType(QueryTypeEnum.STRUCT.getValue())
|
.setQueryType(QueryTypeEnum.STRUCT.getValue())
|
||||||
.setQueryTypeBack(QueryTypeBackEnum.NORMAL.getState())
|
.setQueryTypeBack(QueryTypeBackEnum.NORMAL.getState())
|
||||||
@@ -105,6 +154,25 @@ public class StatUtils {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<String> getFieldNames(List<String> allFields, List<? extends SchemaItem> schemaItems) {
|
||||||
|
Set<String> fieldNames = schemaItems
|
||||||
|
.stream()
|
||||||
|
.map(dimSchemaResp -> dimSchemaResp.getBizName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
if (!CollectionUtils.isEmpty(fieldNames)) {
|
||||||
|
return allFields.stream().filter(fieldName -> fieldNames.contains(fieldName))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getUserName(User facadeUser) {
|
||||||
|
return (Objects.nonNull(facadeUser) && Strings.isNotEmpty(facadeUser.getName())) ? facadeUser.getName()
|
||||||
|
: "Admin";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend) {
|
public List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend) {
|
||||||
return statRepository.getStatInfo(itemUseCommend);
|
return statRepository.getStatInfo(itemUseCommend);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,6 +64,12 @@
|
|||||||
<if test="modelId != null">
|
<if test="modelId != null">
|
||||||
and model_id = #{modelId}
|
and model_id = #{modelId}
|
||||||
</if>
|
</if>
|
||||||
|
<if test="modelIds != null and modelIds.size() > 0">
|
||||||
|
and model_id in
|
||||||
|
<foreach item="id" collection="modelIds" open="(" separator="," close=")">
|
||||||
|
#{id}
|
||||||
|
</foreach>
|
||||||
|
</if>
|
||||||
<if test="metric != null">
|
<if test="metric != null">
|
||||||
and metrics like concat('%',#{metric},'%')
|
and metrics like concat('%',#{metric},'%')
|
||||||
</if>
|
</if>
|
||||||
|
|||||||
Reference in New Issue
Block a user