Merge branch 'master' into feature/lxw

jolunoluo
2023-09-22 10:01:33 +08:00
87 changed files with 27565 additions and 504 deletions

View File

@@ -8,9 +8,11 @@ import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
@@ -32,14 +34,27 @@ import java.util.List;
public interface SemanticLayer {
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
List<ModelSchema> getModelSchema();
List<ModelSchema> getModelSchema(List<Long> ids);
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user);
List<DomainResp> getDomainList(User user);
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
}
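A note on the new explain contract: it lets a caller preview the SQL a request would generate without executing it, and the type parameter is the concrete query request wrapped by ExplainSqlReq. A minimal caller-side sketch, not part of this commit (semanticLayer, user, modelId and log are assumed to be in scope; the DSL string is invented):

    QueryDslReq dslReq = QueryReqBuilder.buildDslReq("SELECT dept, COUNT(*) FROM visits GROUP BY dept", modelId);
    ExplainSqlReq explainReq = ExplainSqlReq.builder()
            .queryTypeEnum(QueryTypeEnum.SQL)
            .queryReq(dslReq)
            .build();
    try {
        ExplainResp explainResp = semanticLayer.explain(explainReq, user);
        if (explainResp != null) {
            log.info("explained sql:{}", explainResp.getSql()); // the SQL the engine would run
        }
    } catch (Exception e) {
        log.error("explain failed", e);
    }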

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import org.apache.calcite.sql.parser.SqlParseException;
/**
@@ -14,6 +15,8 @@ public interface SemanticQuery {
QueryResult execute(User user) throws SqlParseException;
ExplainResp explain(User user);
SemanticParseInfo getParseInfo();
void setParseInfo(SemanticParseInfo parseInfo);

View File

@@ -37,6 +37,8 @@ public class SemanticParseInfo {
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
private String sql;
public Long getModelId() {
return model != null ? model.getId() : 0L;
}
@@ -46,6 +48,7 @@ public class SemanticParseInfo {
}
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override
public int compare(SchemaElement o1, SchemaElement o2) {
int len1 = o1.getName().length();

View File

@@ -16,4 +16,10 @@ public class LLMParserConfig {
@Value("${query2sql.path:/query2sql}")
private String queryToSqlPath;
@Value("${dimension.topn:5}")
private Integer dimensionTopN;
@Value("${metric.topn:5}")
private Integer metricTopN;
}

View File

@@ -39,6 +39,7 @@ import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -87,7 +88,7 @@ public class LLMDslParser implements SemanticParser {
return;
}
LLMReq llmReq = getLlmReq(queryCtx, modelId);
LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig);
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
if (Objects.isNull(llmResp)) {
@@ -340,22 +341,28 @@ public class LLMDslParser implements SemanticParser {
return null;
}
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = DSLDateHelper.getReferenceDate(modelId);
llmReq.setCurrentDate(currentDate);
return llmReq;
@@ -399,12 +406,29 @@ public class LLMDslParser implements SemanticParser {
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
Set<String> results = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
return new ArrayList<>(results);
}
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
@@ -423,7 +447,8 @@ public class LLMDslParser implements SemanticParser {
})
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
.collect(Collectors.toSet());
return new ArrayList<>(fieldNameList);
results.addAll(fieldNameList);
return new ArrayList<>(results);
}
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
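Net effect of the reworked field selection: even when the mapper matched nothing for the model, the prompt still receives the most-used dimensions and metrics, capped by dimension.topn and metric.topn (both default to 5 in LLMParserConfig), with any matched element names unioned on top. A standalone sketch of the same sort-by-useCnt / limit / union pattern, using invented field names and counts (assumes java.util and java.util.stream imports):

    Map<String, Integer> useCnt = new LinkedHashMap<>();
    useCnt.put("department", 40);
    useCnt.put("city", 25);
    useCnt.put("channel", 10);
    useCnt.put("os", 3);
    int topN = 2; // stands in for llmParserConfig.getDimensionTopN()
    Set<String> fields = useCnt.entrySet().stream()
            .sorted(Map.Entry.<String, Integer>comparingByValue().reversed())
            .limit(topN)
            .map(Map.Entry::getKey)
            .collect(Collectors.toCollection(LinkedHashSet::new));
    fields.add("pv"); // a name matched for the current question is always kept
    System.out.println(fields); // [department, city, pv]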

View File

@@ -15,7 +15,10 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import java.util.ArrayList;
import java.util.List;
@@ -42,12 +45,10 @@ public class DslQuery extends PluginSemanticQuery {
@Override
public QueryResult execute(User user) {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
LLMResp llmResp = getLlmResp();
long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
QueryDslReq queryDslReq = getQueryDslReq(llmResp);
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput());
@@ -71,4 +72,30 @@ public class DslQuery extends PluginSemanticQuery {
parseInfo.setProperties(null);
return queryResult;
}
private LLMResp getLlmResp() {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
return dslParseResult.getLlmResp();
}
private QueryDslReq getQueryDslReq(LLMResp llmResp) {
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId());
return queryDslReq;
}
@Override
public ExplainResp explain(User user) {
ExplainSqlReq explainSqlReq = null;
try {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(getQueryDslReq(getLlmResp()))
.build();
return semanticLayer.explain(explainSqlReq, user);
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return null;
}
}

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.query.plugin;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@@ -17,5 +19,8 @@ public abstract class PluginSemanticQuery implements SemanticQuery {
return parseInfo;
}
@Override
public ExplainResp explain(User user) {
return null;
}
}

View File

@@ -21,8 +21,11 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
@@ -215,6 +218,22 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return queryResult;
}
@Override
public ExplainResp explain(User user) {
ExplainSqlReq explainSqlReq = null;
try {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.STRUCT)
.queryReq(convertQueryStruct())
.build();
return semanticLayer.explain(explainSqlReq, user);
} catch (Exception e) {
log.error("explain error explainSqlReq:{}", explainSqlReq, e);
}
return null;
}
public QueryResult multiStructExecute(User user) {
String queryMode = parseInfo.getQueryMode();

View File

@@ -110,17 +110,18 @@ public class ChatConfigController {
}
@PostMapping("/dimension/page")
public PageInfo<DimensionResp> getDimension(@RequestBody PageDimensionReq pageDimensionCmd,
public PageInfo<DimensionResp> getDimension(@RequestBody PageDimensionReq pageDimensionReq,
HttpServletRequest request,
HttpServletResponse response) {
return semanticLayer.getDimensionPage(pageDimensionCmd);
return semanticLayer.getDimensionPage(pageDimensionReq);
}
@PostMapping("/metric/page")
public PageInfo<MetricResp> getMetric(@RequestBody PageMetricReq pageMetrricCmd,
public PageInfo<MetricResp> getMetric(@RequestBody PageMetricReq pageMetricReq,
HttpServletRequest request,
HttpServletResponse response) {
return semanticLayer.getMetricPage(pageMetrricCmd);
User user = UserHolder.findUser(request, response);
return semanticLayer.getMetricPage(pageMetricReq, user);
}

View File

@@ -29,6 +29,7 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
@@ -37,9 +38,7 @@ import java.util.Comparator;
import java.util.Objects;
import java.util.stream.Collectors;
//import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
//import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
@@ -68,8 +67,6 @@ public class QueryServiceImpl implements QueryService {
@Autowired
private QueryResponder queryResponder;
private final String entity = "ENTITY";
@Value("${time.threshold: 100}")
private Integer timeThreshold;
@@ -113,12 +110,16 @@ public class QueryServiceImpl implements QueryService {
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
selectedParses.forEach(parseInfo -> {
if (parseInfo.getQueryMode().contains(entity)) {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.isEntityQuery(queryMode)) {
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class)
.getEntityInfo(parseInfo, queryReq.getUser());
parseInfo.setEntityInfo(entityInfo);
}
addExplainSql(queryReq, parseInfo);
});
List<SemanticParseInfo> candidateParses = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
@@ -145,6 +146,19 @@ public class QueryServiceImpl implements QueryService {
return parseResult;
}
private void addExplainSql(QueryReq queryReq, SemanticParseInfo parseInfo) {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
ExplainResp explain = semanticQuery.explain(queryReq.getUser());
if (Objects.isNull(explain)) {
return;
}
parseInfo.setSql(explain.getSql());
}
@Override
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(),
@@ -162,9 +176,9 @@ public class QueryServiceImpl implements QueryService {
chatCtx.setAgentId(queryReq.getAgentId());
Long startTime = System.currentTimeMillis();
QueryResult queryResult = semanticQuery.execute(queryReq.getUser());
Long endTime = System.currentTimeMillis();
if (queryResult != null) {
timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime))
timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime))
.interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build());
saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getChatId().longValue());
@@ -176,7 +190,6 @@ public class QueryServiceImpl implements QueryService {
}
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());
//chatService.addQuery(queryResult, chatCtx);
chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx);
queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId());
} else {
@@ -187,8 +200,8 @@ public class QueryServiceImpl implements QueryService {
}
public void saveInfo(List<StatisticsDO> timeCostDOList,
String queryText, Long queryId,
String userName, Long chatId) {
List<StatisticsDO> list = timeCostDOList.stream()
.filter(o -> o.getCost() > timeThreshold).collect(Collectors.toList());
list.forEach(o -> {
@@ -272,13 +285,6 @@ public class QueryServiceImpl implements QueryService {
dateConf.setPeriod("DAY");
queryStructReq.setDateInfo(dateConf);
queryStructReq.setLimit(20L);
// List<Aggregator> aggregators = new ArrayList<>();
// Aggregator aggregator = new Aggregator(dimensionValueReq.getQueryFilter().getBizName(),
// AggOperatorEnum.DISTINCT);
// aggregators.add(aggregator);
// queryStructReq.setAggregators(aggregators);
queryStructReq.setModelId(dimensionValueReq.getModelId());
queryStructReq.setNativeQuery(true);
List<String> groups = new ArrayList<>();

View File

@@ -15,6 +15,7 @@ CHROMA_DB_PERSIST_DIR = 'chm_db'
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
TEXT2DSL_IS_SHORTCUT = False
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)

View File

@@ -22,10 +22,8 @@ from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
def reload_sql_example_collection(vectorstore:Chroma,
sql_examplars:List[Mapping[str, str]],
schema_linking_example_selector:SemanticSimilarityExampleSelector,
sql_example_selector:SemanticSimilarityExampleSelector,
example_nums:int
):
@@ -35,20 +33,16 @@ def reload_sql_example_collection(vectorstore:Chroma,
print("emptied sql_examples_collection size:", vectorstore._collection.count())
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
for example in sql_examplars:
schema_linking_example_selector.add_example(example)
sql_example_selector.add_example(example)
print("reloaded sql_examples_collection size:", vectorstore._collection.count())
return vectorstore, schema_linking_example_selector, sql_example_selector
return vectorstore, sql_example_selector
sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
@@ -57,22 +51,14 @@ sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
if sql_examples_vectorstore._collection.count() > 0:
print("examples already in sql_vectorstore")
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
if sql_examples_vectorstore._collection.count() < len(sql_examplars):
print("sql_examplars size:", len(sql_examplars))
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
else:
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
print("sql_examplars size:", len(sql_examplars))
sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())

View File

@@ -8,24 +8,22 @@ import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT
from few_shot_example.sql_exampler import examplars as sql_examplars
from run_config import LLMPARSER_HOST
from run_config import LLMPARSER_PORT
from run_config import LLMPARSER_HOST, LLMPARSER_PORT
def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str,
sql_examplars:List[Mapping[str, str]], example_nums:int):
sql_examplars:List[Mapping[str, str]], example_nums:int, is_shortcut:bool):
url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/"
print("url: ", url)
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums}
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums, "isShortcut":is_shortcut}
headers = {'content-type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
print(response.text)
if __name__ == "__main__":
arguments = sys.argv
text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT,
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM)
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT)

View File

@@ -10,4 +10,36 @@ def schema_link_parse(schema_link_output):
print(e)
schema_link_output = None
return schema_link_output
def combo_schema_link_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'Schema_links:(\[.*?\])'
schema_links_match = re.search(pattern, schema_linking_sql_combo_output)
if schema_links_match:
schema_links = schema_links_match.group(1)
else:
schema_links = None
except Exception as e:
print(e)
schema_links = None
return schema_links
def combo_sql_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'SQL:(.*)'
sql_match = re.search(pattern, schema_linking_sql_combo_output)
if sql_match:
sql = sql_match.group(1)
else:
sql = None
except Exception as e:
print(e)
sql = None
return sql

View File

@@ -73,3 +73,38 @@ def sql_exampler(user_query: str,
schema_links=schema_link_str)
return sql_example_prompt
def schema_linking_sql_combo_examplar(user_query: str,
domain_name: str,
data_date : str,
fields_list: List[str],
prior_schema_links: Mapping[str,str],
example_selector: SemanticSimilarityExampleSelector) -> str:
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}")
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question"],
suffix=schema_linking_sql_combo_prompt
)
schema_linking_sql_combo_example_prompt = schema_linking_sql_combo_example_prompt_template.format(table_name=domain_name,
fields_list=fields_list,
prior_schema_links=prior_schema_links_str,
current_date=data_date,
question=user_query)
return schema_linking_sql_combo_example_prompt

View File

@@ -7,32 +7,37 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sql.prompt_maker import schema_linking_exampler, sql_exampler
from sql.constructor import schema_linking_example_selector, sql_example_selector,sql_examples_vectorstore, reload_sql_example_collection
from sql.output_parser import schema_link_parse
from sql.prompt_maker import schema_linking_exampler, sql_exampler, schema_linking_sql_combo_examplar
from sql.constructor import sql_examples_vectorstore, sql_example_selector, reload_sql_example_collection
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
from util.llm_instance import llm
from run_config import TEXT2DSL_IS_SHORTCUT
class Text2DSLAgent(object):
def __init__(self):
self.schema_linking_exampler = schema_linking_exampler
self.sql_exampler = sql_exampler
self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar
self.sql_examples_vectorstore = sql_examples_vectorstore
self.schema_linking_example_selector = schema_linking_example_selector
self.sql_example_selector = sql_example_selector
self.schema_link_parse = schema_link_parse
self.combo_schema_link_parse = combo_schema_link_parse
self.combo_sql_parse = combo_sql_parse
self.llm = llm
def update_examples(self, sql_examplars, example_nums):
self.sql_examples_vectorstore, self.schema_linking_example_selector, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
sql_examplars,
self.schema_linking_example_selector,
self.sql_example_selector,
example_nums)
self.is_shortcut = TEXT2DSL_IS_SHORTCUT
def update_examples(self, sql_examples, example_nums, is_shortcut):
self.sql_examples_vectorstore, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
sql_examples,
self.sql_example_selector,
example_nums)
self.is_shortcut = is_shortcut
def query2sql(self, query_text: str,
schema : Union[dict, None] = None,
@@ -53,14 +58,14 @@ class Text2DSLAgent(object):
model_name = schema['modelName']
fields_list = schema['fieldNameList']
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.schema_linking_example_selector)
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.sql_example_selector)
print("schema_linking_prompt->", schema_linking_prompt)
schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output)
sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector)
print("sql_prompt->", sql_prompt)
sql_output = llm(sql_prompt)
sql_output = self.llm(sql_prompt)
resp = dict()
resp['query'] = query_text
@@ -69,7 +74,7 @@ class Text2DSLAgent(object):
resp['priorSchemaLinking'] = linking
resp['dataDate'] = current_date
resp['schemaLinkingOutput'] = schema_link_output
resp['analysisOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
@@ -78,5 +83,57 @@ class Text2DSLAgent(object):
return resp
def query2sqlcombo(self, query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None
):
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
if linking is not None:
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
else:
prior_schema_links = {}
model_name = schema['modelName']
fields_list = schema['fieldNameList']
schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(query_text, model_name, current_date, fields_list,
prior_schema_links, self.sql_example_selector)
print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
schema_linking_str = self.combo_schema_link_parse(schema_linking_sql_combo_output)
sql_str = self.combo_sql_parse(schema_linking_sql_combo_output)
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = current_date
resp['analysisOutput'] = schema_linking_sql_combo_output
resp['schemaLinkStr'] = schema_linking_str
resp['sqlOutput'] = sql_str
print("resp: ", resp)
return resp
def query2sql_run(self, query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None):
if self.is_shortcut:
return self.query2sqlcombo(query_text, schema, current_date, linking)
else:
return self.query2sql(query_text, schema, current_date, linking)
text2sql_agent = Text2DSLAgent()

View File

@@ -51,7 +51,7 @@ async def din_query2sql(query_body: Mapping[str, Any]):
else:
linking = query_body['linking']
resp = text2sql_agent.query2sql(query_text=query_text,
resp = text2sql_agent.query2sql_run(query_text=query_text,
schema=schema, current_date=current_date, linking=linking)
return resp
@@ -70,7 +70,12 @@ async def query2sql_setting_update(query_body: Mapping[str, Any]):
else:
example_nums = query_body['exampleNums']
text2sql_agent.update_examples(sql_examplars=sql_examplars, example_nums=example_nums)
if 'isShortcut' not in query_body:
raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
else:
is_shortcut = query_body['isShortcut']
text2sql_agent.update_examples(sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut)
return "success"

View File

@@ -38,4 +38,7 @@ public class DefaultSemanticConfig {
@Value("${fetchModelList.path:/api/semantic/schema/model/list}")
private String fetchModelListPath;
@Value("${explain.path:/api/semantic/query/explain}")
private String explainPath;
}

View File

@@ -8,12 +8,14 @@ import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
@@ -79,8 +81,9 @@ public class LocalSemanticLayer extends BaseSemanticLayer {
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
filter.setModelIds(ids);
modelService = ContextUtils.getBean(ModelService.class);
return modelService.fetchModelSchema(filter);
schemaService = ContextUtils.getBean(SchemaService.class);
User user = User.getFakeUser();
return schemaService.fetchModelSchema(filter, user);
}
@Override
@@ -95,6 +98,12 @@ public class LocalSemanticLayer extends BaseSemanticLayer {
return schemaService.getModelList(user, authType, domainId);
}
@Override
public <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception {
queryService = ContextUtils.getBean(QueryService.class);
return queryService.explain(explainSqlReq, user);
}
@Override
public PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd) {
dimensionService = ContextUtils.getBean(DimensionService.class);
@@ -102,9 +111,9 @@ public class LocalSemanticLayer extends BaseSemanticLayer {
}
@Override
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricReq) {
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricReq, User user) {
metricService = ContextUtils.getBean(MetricService.class);
return metricService.queryMetric(pageMetricReq);
return metricService.queryMetric(pageMetricReq, user);
}
}

View File

@@ -1,38 +1,44 @@
package com.tencent.supersonic.knowledge.semantic;
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER;
import com.alibaba.fastjson.JSON;
import com.github.pagehelper.PageInfo;
import com.google.gson.Gson;
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
import com.tencent.supersonic.auth.api.authentication.constant.UserConstants;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.ResultData;
import com.tencent.supersonic.common.pojo.ReturnCode;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.pojo.exception.CommonException;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.S2ThreadContext;
import com.tencent.supersonic.common.util.ThreadContext;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.common.pojo.exception.CommonException;
import com.tencent.supersonic.common.pojo.ResultData;
import com.tencent.supersonic.common.pojo.ReturnCode;
import java.net.URI;
import java.net.URL;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import java.util.LinkedHashMap;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
@@ -45,11 +51,6 @@ import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
@Slf4j
public class RemoteSemanticLayer extends BaseSemanticLayer {
@@ -61,6 +62,10 @@ public class RemoteSemanticLayer extends BaseSemanticLayer {
new ParameterizedTypeReference<ResultData<QueryResultWithSchemaResp>>() {
};
private ParameterizedTypeReference<ResultData<ExplainResp>> explainTypeRef =
new ParameterizedTypeReference<ResultData<ExplainResp>>() {
};
@Override
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
@@ -130,9 +135,10 @@ public class RemoteSemanticLayer extends BaseSemanticLayer {
fillToken(headers);
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchModelSchemaPath()).build()
.encode().toUri();
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
String fetchModelSchemaPath = defaultSemanticConfig.getFetchModelSchemaPath();
URI requestUrl = UriComponentsBuilder.fromHttpUrl(semanticUrl + fetchModelSchemaPath)
.build().encode().toUri();
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
filter.setModelIds(ids);
ParameterizedTypeReference<ResultData<List<ModelSchemaResp>>> responseTypeRef =
@@ -179,6 +185,39 @@ public class RemoteSemanticLayer extends BaseSemanticLayer {
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), ModelResp.class);
}
@Override
public <T> ExplainResp explain(ExplainSqlReq<T> explainResp, User user) throws Exception {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
String explainPath = defaultSemanticConfig.getExplainPath();
URL url = new URL(new URL(semanticUrl), explainPath);
return explain(url.toString(), JsonUtil.toString(explainResp));
}
public ExplainResp explain(String url, String jsonReq) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
fillToken(headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonReq, headers);
log.info("url:{},explain:{}", url, entity.getBody());
ResultData<ExplainResp> responseBody;
try {
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<ResultData<ExplainResp>> responseEntity = restTemplate.exchange(
requestUrl, HttpMethod.POST, entity, explainTypeRef);
log.info("ApiResponse<ExplainResp> responseBody:{}", responseEntity);
responseBody = responseEntity.getBody();
if (Objects.nonNull(responseBody.getData())) {
return responseBody.getData();
}
return null;
} catch (Exception e) {
throw new RuntimeException("explain interface error,url:" + url, e);
}
}
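Note that, unlike the other remote calls here which concatenate semanticUrl with a path string, explain resolves explainPath against the base with java.net.URL, so a path beginning with "/" replaces whatever path segment semanticUrl carries. A small illustration with an invented base URL (inside a method that declares MalformedURLException):

    URL base = new URL("http://semantic-host:8080/supersonic");
    URL full = new URL(base, "/api/semantic/query/explain");
    System.out.println(full); // http://semantic-host:8080/api/semantic/query/explain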
public Object fetchHttpResult(String url, String bodyJson, HttpMethod httpMethod) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
@@ -219,7 +258,7 @@ public class RemoteSemanticLayer extends BaseSemanticLayer {
}
@Override
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd) {
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user) {
String body = JsonUtil.toString(pageMetricCmd);
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
log.info("url:{}", defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath());

View File

@@ -18,7 +18,7 @@ public class SchemaService {
public static final String ALL_CACHE = "all";
private static final Integer META_CACHE_TIME = 5;
private static final Integer META_CACHE_TIME = 2;
private SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
private LoadingCache<String, SemanticSchema> cache = CacheBuilder.newBuilder()
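The hunk is truncated here. For context, a minimal sketch of how a Guava LoadingCache is typically wired with such a TTL; the time unit, loader body and helper name are assumptions, and the change above only lowers META_CACHE_TIME from 5 to 2:

    LoadingCache<String, SemanticSchema> cache = CacheBuilder.newBuilder()
            .expireAfterWrite(META_CACHE_TIME, TimeUnit.MINUTES) // assumed unit
            .build(new CacheLoader<String, SemanticSchema>() {
                @Override
                public SemanticSchema load(String key) {
                    return fetchSemanticSchema(key); // hypothetical reload helper
                }
            });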