mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
(improvement)(headless)Refactor LLMParser impl naming and structure.
This commit is contained in:
@@ -11,7 +11,7 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.PythonLLMProxy;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.PythonLLMProxy;
|
||||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import java.util.stream.Collectors;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class KnowledgeService {
|
public class KnowledgeBaseService {
|
||||||
|
|
||||||
public void updateSemanticKnowledge(List<DictWord> natures) {
|
public void updateSemanticKnowledge(List<DictWord> natures) {
|
||||||
|
|
||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.mapper;
|
|||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -36,7 +36,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private KnowledgeService knowledgeService;
|
private KnowledgeBaseService knowledgeBaseService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
@@ -65,11 +65,11 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
String detectSegment, int offset) {
|
String detectSegment, int offset) {
|
||||||
// step1. pre search
|
// step1. pre search
|
||||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
// step2. suffix search
|
// step2. suffix search
|
||||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment,
|
||||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -29,7 +29,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
private static final int SEARCH_SIZE = 3;
|
private static final int SEARCH_SIZE = 3;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private KnowledgeService knowledgeService;
|
private KnowledgeBaseService knowledgeBaseService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||||
@@ -57,9 +57,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
String detectSegment = text.substring(detectIndex);
|
String detectSegment = text.substring(detectIndex);
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||||
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
|
||||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
// remove entity name where search
|
// remove entity name where search
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.parser;
|
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlGeneration;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlGenerationFactory;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* LLMProxy based on langchain4j Java version.
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
@Component
|
|
||||||
public class JavaLLMProxy implements LLMProxy {
|
|
||||||
|
|
||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isSkip(QueryContext queryContext) {
|
|
||||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
|
||||||
if (Objects.isNull(chatLanguageModel)) {
|
|
||||||
log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName());
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
|
|
||||||
|
|
||||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
|
||||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
|
||||||
String modelName = llmReq.getSchema().getDataSetName();
|
|
||||||
LLMResp result = sqlGeneration.generation(llmReq, dataSetId);
|
|
||||||
result.setQuery(llmReq.getQueryText());
|
|
||||||
result.setModelName(modelName);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class SqlExample {
|
public class Exemplar {
|
||||||
|
|
||||||
private String question;
|
private String question;
|
||||||
|
|
||||||
@@ -27,29 +27,29 @@ import java.util.stream.Collectors;
|
|||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
public class SqlExamplarLoader {
|
public class ExemplarManager {
|
||||||
|
|
||||||
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
|
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
|
||||||
|
|
||||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
|
||||||
};
|
};
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
public List<SqlExample> getSqlExamples() throws IOException {
|
public List<Exemplar> getExemplars() throws IOException {
|
||||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||||
InputStream inputStream = resource.getInputStream();
|
InputStream inputStream = resource.getInputStream();
|
||||||
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
|
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
||||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||||
for (int i = 0; i < sqlExamples.size(); i++) {
|
for (int i = 0; i < exemplars.size(); i++) {
|
||||||
SqlExample sqlExample = sqlExamples.get(i);
|
Exemplar exemplar = exemplars.get(i);
|
||||||
String question = sqlExample.getQuestion();
|
String question = exemplar.getQuestion();
|
||||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class);
|
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
|
||||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||||
embeddingQuery.setQueryId(String.valueOf(i));
|
embeddingQuery.setQueryId(String.valueOf(i));
|
||||||
embeddingQuery.setQuery(question);
|
embeddingQuery.setQuery(question);
|
||||||
@@ -59,7 +59,7 @@ public class SqlExamplarLoader {
|
|||||||
s2EmbeddingStore.addQuery(collectionName, queries);
|
s2EmbeddingStore.addQuery(collectionName, queries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Map<String, String>> retrieverSqlExamples(String queryText, int maxResults) {
|
public List<Map<String, String>> recallExemplars(String queryText, int maxResults) {
|
||||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||||
.queryEmbeddings(null).build();
|
.queryEmbeddings(null).build();
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||||
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLMProxy based on langchain4j Java version.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class JavaLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
public LLMResp text2sql(LLMReq llmReq) {
|
||||||
|
|
||||||
|
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(
|
||||||
|
SqlGenType.getMode(llmReq.getSqlGenerationMode()));
|
||||||
|
String modelName = llmReq.getSchema().getDataSetName();
|
||||||
|
LLMResp result = sqlGenStrategy.generate(llmReq);
|
||||||
|
result.setQuery(llmReq.getQueryText());
|
||||||
|
result.setModelName(modelName);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.parser;
|
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
|
|
||||||
@@ -12,8 +11,6 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
|||||||
*/
|
*/
|
||||||
public interface LLMProxy {
|
public interface LLMProxy {
|
||||||
|
|
||||||
boolean isSkip(QueryContext queryContext);
|
LLMResp text2sql(LLMReq llmReq);
|
||||||
|
|
||||||
LLMResp query2sql(LLMReq llmReq, Long dataSetId);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -45,13 +45,12 @@ public class LLMRequestService {
|
|||||||
log.info("not enable llm, skip");
|
log.info("not enable llm, skip");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (ComponentFactory.getLLMProxy().isSkip(queryCtx)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (SatisfactionChecker.isSkip(queryCtx)) {
|
if (SatisfactionChecker.isSkip(queryCtx)) {
|
||||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
|
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,6 +71,7 @@ public class LLMRequestService {
|
|||||||
llmReq.setFilterCondition(filterCondition);
|
llmReq.setFilterCondition(filterCondition);
|
||||||
|
|
||||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||||
|
llmSchema.setDataSetId(dataSetId);
|
||||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||||
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
||||||
|
|
||||||
@@ -95,13 +95,13 @@ public class LLMRequestService {
|
|||||||
currentDate = DateUtils.getBeforeDate(0);
|
currentDate = DateUtils.getBeforeDate(0);
|
||||||
}
|
}
|
||||||
llmReq.setCurrentDate(currentDate);
|
llmReq.setCurrentDate(currentDate);
|
||||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
|
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenType().getName());
|
||||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||||
return llmReq;
|
return llmReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMResp requestLLM(LLMReq llmReq, Long dataSetId) {
|
public LLMResp invokeLLM(LLMReq llmReq) {
|
||||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, dataSetId);
|
return ComponentFactory.getLLMProxy().text2sql(llmReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
List<LLMReq.ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
|
List<LLMReq.ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
||||||
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
|
LLMResp llmResp = requestService.invokeLLM(llmReq);
|
||||||
if (Objects.isNull(llmResp)) {
|
if (Objects.isNull(llmResp)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -21,21 +20,21 @@ import java.util.stream.Collectors;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OnePassSCSqlGeneration extends BaseSqlGeneration {
|
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
public LLMResp generate(LLMReq llmReq) {
|
||||||
//1.retriever sqlExamples and generate exampleListPool
|
//1.retriever sqlExamples and generate exampleListPool
|
||||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||||
|
|
||||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||||
|
|
||||||
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
|
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
|
||||||
List<String> linkingSqlPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
||||||
List<String> llmResults = new CopyOnWriteArrayList<>();
|
List<String> llmResults = new CopyOnWriteArrayList<>();
|
||||||
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
|
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
|
||||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
||||||
@@ -67,6 +66,6 @@ public class OnePassSCSqlGeneration extends BaseSqlGeneration {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void afterPropertiesSet() {
|
public void afterPropertiesSet() {
|
||||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
@@ -19,17 +18,17 @@ import java.util.Map;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OnePassSqlGeneration extends BaseSqlGeneration {
|
public class OnePassSqlGenStrategy extends SqlGenStrategy {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
public LLMResp generate(LLMReq llmReq) {
|
||||||
//1.retriever sqlExamples
|
//1.retriever sqlExamples
|
||||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||||
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
||||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||||
@@ -52,6 +51,6 @@ public class OnePassSqlGeneration extends BaseSqlGeneration {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void afterPropertiesSet() {
|
public void afterPropertiesSet() {
|
||||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this);
|
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -13,9 +13,6 @@ import java.util.regex.Matcher;
|
|||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/***
|
|
||||||
* output format
|
|
||||||
*/
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OutputFormat {
|
public class OutputFormat {
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import java.util.Collections;
|
|||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SqlPromptGenerator {
|
public class PromptGenerator {
|
||||||
|
|
||||||
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||||
String instruction =
|
String instruction =
|
||||||
@@ -1,15 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.parser;
|
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.core.config.LLMParserConfig;
|
import com.tencent.supersonic.headless.core.config.LLMParserConfig;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.OutputFormat;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.MapUtils;
|
import org.apache.commons.collections4.MapUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.http.HttpEntity;
|
import org.springframework.http.HttpEntity;
|
||||||
@@ -30,22 +27,12 @@ import java.util.ArrayList;
|
|||||||
@Component
|
@Component
|
||||||
public class PythonLLMProxy implements LLMProxy {
|
public class PythonLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger(PythonLLMProxy.class);
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
@Override
|
public LLMResp text2sql(LLMReq llmReq) {
|
||||||
public boolean isSkip(QueryContext queryContext) {
|
|
||||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
|
||||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
|
||||||
log.warn("llmParserUrl is empty, skip :{}", PythonLLMProxy.class.getName());
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
|
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
log.info("requestLLM request, dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
log.info("requestLLM request, llmReq:{}", llmReq);
|
||||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||||
try {
|
try {
|
||||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
|
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||||
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
|
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -10,22 +12,27 @@ import org.springframework.beans.factory.InitializingBean;
|
|||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SqlGenStrategy abstracts generation step so that
|
||||||
|
* different LLM prompting strategies can be implemented.
|
||||||
|
*/
|
||||||
@Service
|
@Service
|
||||||
public abstract class BaseSqlGeneration implements SqlGeneration, InitializingBean {
|
public abstract class SqlGenStrategy implements InitializingBean {
|
||||||
|
|
||||||
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
protected SqlExamplarLoader sqlExamplarLoader;
|
protected ExemplarManager exemplarManager;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
protected OptimizationConfig optimizationConfig;
|
protected OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
protected SqlPromptGenerator sqlPromptGenerator;
|
protected PromptGenerator promptGenerator;
|
||||||
|
|
||||||
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
|
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
|
||||||
return S2ChatModelProvider.provide(llmConfig);
|
return S2ChatModelProvider.provide(llmConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
abstract LLMResp generate(LLMReq llmReq);
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public class SqlGenStrategyFactory {
|
||||||
|
|
||||||
|
private static Map<LLMReq.SqlGenType, SqlGenStrategy> sqlGenStrategyMap = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
public static SqlGenStrategy get(LLMReq.SqlGenType strategyType) {
|
||||||
|
return sqlGenStrategyMap.get(strategyType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void addSqlGenerationForFactory(LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) {
|
||||||
|
sqlGenStrategyMap.put(strategy, sqlGenStrategy);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sql Generation interface, generating SQL using a large model.
|
|
||||||
*/
|
|
||||||
public interface SqlGeneration {
|
|
||||||
|
|
||||||
/***
|
|
||||||
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
|
||||||
* @param llmReq
|
|
||||||
* @param dataSetId
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
LLMResp generation(LLMReq llmReq, Long dataSetId);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
public class SqlGenerationFactory {
|
|
||||||
|
|
||||||
private static Map<SqlGenerationMode, SqlGeneration> sqlGenerationMap = new ConcurrentHashMap<>();
|
|
||||||
|
|
||||||
public static SqlGeneration get(SqlGenerationMode strategyType) {
|
|
||||||
return sqlGenerationMap.get(strategyType);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) {
|
|
||||||
sqlGenerationMap.put(strategy, sqlGeneration);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -18,20 +18,20 @@ import java.util.Map;
|
|||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
|
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
public LLMResp generate(LLMReq llmReq) {
|
||||||
//1.retriever sqlExamples and generate exampleListPool
|
//1.retriever sqlExamples and generate exampleListPool
|
||||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||||
|
|
||||||
//2.generator linking prompt,and parallel generate response.
|
//2.generator linking prompt,and parallel generate response.
|
||||||
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
List<String> linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||||
linkingPromptPool.parallelStream().forEach(
|
linkingPromptPool.parallelStream().forEach(
|
||||||
@@ -47,7 +47,7 @@ public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
|
|||||||
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
||||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(sortedList);
|
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(sortedList);
|
||||||
//3.generator sql prompt,and parallel generate response.
|
//3.generator sql prompt,and parallel generate response.
|
||||||
List<String> sqlPromptPool = sqlPromptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
|
List<String> sqlPromptPool = promptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
|
||||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||||
@@ -69,6 +69,6 @@ public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void afterPropertiesSet() {
|
public void afterPropertiesSet() {
|
||||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -18,15 +18,15 @@ import java.util.Map;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TwoPassSqlGeneration extends BaseSqlGeneration {
|
public class TwoPassSqlGenStrategy extends SqlGenStrategy {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
public LLMResp generate(LLMReq llmReq) {
|
||||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||||
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
|
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
|
||||||
@@ -34,7 +34,7 @@ public class TwoPassSqlGeneration extends BaseSqlGeneration {
|
|||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||||
keyPipelineLog.info("step one model response:{}", response.content().text());
|
keyPipelineLog.info("step one model response:{}", response.content().text());
|
||||||
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
|
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
|
||||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
|
String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
|
||||||
|
|
||||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||||
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
|
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
|
||||||
@@ -53,6 +53,6 @@ public class TwoPassSqlGeneration extends BaseSqlGeneration {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void afterPropertiesSet() {
|
public void afterPropertiesSet() {
|
||||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT, this);
|
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -41,6 +41,8 @@ public class LLMReq {
|
|||||||
|
|
||||||
private String dataSetName;
|
private String dataSetName;
|
||||||
|
|
||||||
|
private Long dataSetId;
|
||||||
|
|
||||||
private List<String> fieldNameList;
|
private List<String> fieldNameList;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -51,7 +53,7 @@ public class LLMReq {
|
|||||||
private String tableName;
|
private String tableName;
|
||||||
}
|
}
|
||||||
|
|
||||||
public enum SqlGenerationMode {
|
public enum SqlGenType {
|
||||||
|
|
||||||
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
|
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
|
||||||
|
|
||||||
@@ -64,7 +66,7 @@ public class LLMReq {
|
|||||||
|
|
||||||
private String name;
|
private String name;
|
||||||
|
|
||||||
SqlGenerationMode(String name) {
|
SqlGenType(String name) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,10 +75,10 @@ public class LLMReq {
|
|||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static SqlGenerationMode getMode(String name) {
|
public static SqlGenType getMode(String name) {
|
||||||
for (SqlGenerationMode sqlGenerationMode : SqlGenerationMode.values()) {
|
for (SqlGenType sqlGenType : SqlGenType.values()) {
|
||||||
if (sqlGenerationMode.name.equals(name)) {
|
if (sqlGenType.name.equals(name)) {
|
||||||
return sqlGenerationMode;
|
return sqlGenType;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.headless.core.config;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.service.SysParameterService;
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -65,7 +65,7 @@ public class OptimizationConfig {
|
|||||||
private boolean useLinkingValueSwitch;
|
private boolean useLinkingValueSwitch;
|
||||||
|
|
||||||
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
|
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
|
||||||
private SqlGenerationMode sqlGenerationMode;
|
private LLMReq.SqlGenType sqlGenType;
|
||||||
|
|
||||||
@Value("${s2SQL.use.switch:true}")
|
@Value("${s2SQL.use.switch:true}")
|
||||||
private boolean useS2SqlSwitch;
|
private boolean useS2SqlSwitch;
|
||||||
@@ -157,8 +157,8 @@ public class OptimizationConfig {
|
|||||||
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SqlGenerationMode getSqlGenerationMode() {
|
public LLMReq.SqlGenType getSqlGenType() {
|
||||||
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
|
return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Integer getParseShowCount() {
|
public Integer getParseShowCount() {
|
||||||
@@ -177,8 +177,8 @@ public class OptimizationConfig {
|
|||||||
return targetType.cast(Integer.parseInt(value));
|
return targetType.cast(Integer.parseInt(value));
|
||||||
} else if (targetType == Boolean.class) {
|
} else if (targetType == Boolean.class) {
|
||||||
return targetType.cast(Boolean.parseBoolean(value));
|
return targetType.cast(Boolean.parseBoolean(value));
|
||||||
} else if (targetType == SqlGenerationMode.class) {
|
} else if (targetType == LLMReq.SqlGenType.class) {
|
||||||
return targetType.cast(SqlGenerationMode.valueOf(value));
|
return targetType.cast(LLMReq.SqlGenType.valueOf(value));
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("convertValue", e);
|
log.error("convertValue", e);
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.core.utils;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.core.cache.QueryCache;
|
import com.tencent.supersonic.headless.core.cache.QueryCache;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.LLMProxy;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.LLMProxy;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.DataSetResolver;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.DataSetResolver;
|
||||||
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
|
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
|
||||||
import com.tencent.supersonic.headless.core.parser.SqlParser;
|
import com.tencent.supersonic.headless.core.parser.SqlParser;
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.server.listener;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.DictWord;
|
import com.tencent.supersonic.headless.core.chat.knowledge.DictWord;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||||
import com.tencent.supersonic.headless.server.service.impl.WordService;
|
import com.tencent.supersonic.headless.server.service.impl.WordService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
@@ -21,7 +21,7 @@ import java.util.concurrent.CompletableFuture;
|
|||||||
public class ApplicationStartedListener implements CommandLineRunner {
|
public class ApplicationStartedListener implements CommandLineRunner {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private KnowledgeService knowledgeService;
|
private KnowledgeBaseService knowledgeBaseService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private WordService wordService;
|
private WordService wordService;
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ public class ApplicationStartedListener implements CommandLineRunner {
|
|||||||
|
|
||||||
List<DictWord> dictWords = wordService.getAllDictWords();
|
List<DictWord> dictWords = wordService.getAllDictWords();
|
||||||
wordService.setPreDictWords(dictWords);
|
wordService.setPreDictWords(dictWords);
|
||||||
knowledgeService.reloadAllData(dictWords);
|
knowledgeBaseService.reloadAllData(dictWords);
|
||||||
|
|
||||||
log.debug("ApplicationStartedInit end");
|
log.debug("ApplicationStartedInit end");
|
||||||
isOk = true;
|
isOk = true;
|
||||||
@@ -72,7 +72,7 @@ public class ApplicationStartedListener implements CommandLineRunner {
|
|||||||
}
|
}
|
||||||
log.info("dictWords has changed");
|
log.info("dictWords has changed");
|
||||||
wordService.setPreDictWords(dictWords);
|
wordService.setPreDictWords(dictWords);
|
||||||
knowledgeService.updateOnlineKnowledge(wordService.getAllDictWords());
|
knowledgeBaseService.updateOnlineKnowledge(wordService.getAllDictWords());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("reloadKnowledge error", e);
|
log.error("reloadKnowledge error", e);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.server.listener;
|
package com.tencent.supersonic.headless.server.listener;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy;
|
||||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.headless.server.schedule.EmbeddingTask;
|
import com.tencent.supersonic.headless.server.schedule.EmbeddingTask;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package com.tencent.supersonic.headless.server.listener;
|
package com.tencent.supersonic.headless.server.listener;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlExamplarLoader;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.ExemplarManager;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlExample;
|
import com.tencent.supersonic.headless.core.chat.parser.llm.Exemplar;
|
||||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
@@ -19,7 +19,7 @@ import java.util.List;
|
|||||||
public class SqlEmbeddingListener implements CommandLineRunner {
|
public class SqlEmbeddingListener implements CommandLineRunner {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private SqlExamplarLoader sqlExamplarLoader;
|
private ExemplarManager exemplarManager;
|
||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
@@ -31,9 +31,9 @@ public class SqlEmbeddingListener implements CommandLineRunner {
|
|||||||
public void initSqlExamples() {
|
public void initSqlExamples() {
|
||||||
try {
|
try {
|
||||||
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
||||||
List<SqlExample> sqlExamples = sqlExamplarLoader.getSqlExamples();
|
List<Exemplar> exemplars = exemplarManager.getExemplars();
|
||||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||||
sqlExamplarLoader.addEmbeddingStore(sqlExamples, collectionName);
|
exemplarManager.addExemplars(exemplars, collectionName);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("initSqlExamples error", e);
|
log.error("initSqlExamples error", e);
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
|||||||
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
|
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
|
||||||
import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector;
|
import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||||
@@ -96,7 +96,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ChatContextService chatContextService;
|
private ChatContextService chatContextService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private KnowledgeService knowledgeService;
|
private KnowledgeBaseService knowledgeBaseService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private QueryService queryService;
|
private QueryService queryService;
|
||||||
@Autowired
|
@Autowired
|
||||||
@@ -557,7 +557,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
|
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
|
||||||
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
|
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
|
||||||
//search from prefixSearch
|
//search from prefixSearch
|
||||||
List<HanlpMapResult> hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(),
|
List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(),
|
||||||
2000, modelIdToDataSetIds, dataSetIds);
|
2000, modelIdToDataSetIds, dataSetIds);
|
||||||
HanlpHelper.transLetterOriginal(hanlpMapResultList);
|
HanlpHelper.transLetterOriginal(hanlpMapResultList);
|
||||||
return hanlpMapResultList.stream()
|
return hanlpMapResultList.stream()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.response.DictItemResp;
|
import com.tencent.supersonic.headless.api.pojo.response.DictItemResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp;
|
import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp;
|
||||||
import com.tencent.supersonic.headless.core.file.FileHandler;
|
import com.tencent.supersonic.headless.core.file.FileHandler;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.DictTaskDO;
|
import com.tencent.supersonic.headless.server.persistence.dataobject.DictTaskDO;
|
||||||
import com.tencent.supersonic.headless.server.persistence.repository.DictRepository;
|
import com.tencent.supersonic.headless.server.persistence.repository.DictRepository;
|
||||||
@@ -46,7 +46,7 @@ public class DictTaskServiceImpl implements DictTaskService {
|
|||||||
DictUtils dictConverter,
|
DictUtils dictConverter,
|
||||||
DictUtils dictUtils,
|
DictUtils dictUtils,
|
||||||
FileHandler fileHandler,
|
FileHandler fileHandler,
|
||||||
KnowledgeService knowledgeService) {
|
KnowledgeBaseService knowledgeBaseService) {
|
||||||
this.dictRepository = dictRepository;
|
this.dictRepository = dictRepository;
|
||||||
this.dictConverter = dictConverter;
|
this.dictConverter = dictConverter;
|
||||||
this.dictUtils = dictUtils;
|
this.dictUtils = dictUtils;
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|||||||
import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat;
|
import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.DictWord;
|
import com.tencent.supersonic.headless.core.chat.knowledge.DictWord;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||||
import com.tencent.supersonic.headless.server.service.ChatContextService;
|
import com.tencent.supersonic.headless.server.service.ChatContextService;
|
||||||
@@ -58,7 +58,7 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ChatContextService chatContextService;
|
private ChatContextService chatContextService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private KnowledgeService knowledgeService;
|
private KnowledgeBaseService knowledgeBaseService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private DataSetService dataSetService;
|
private DataSetService dataSetService;
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ public class SearchServiceImpl implements SearchService {
|
|||||||
Map<Long, List<Long>> modelIdToDataSetIds =
|
Map<Long, List<Long>> modelIdToDataSetIds =
|
||||||
dataSetService.getModelIdToDataSetIds(new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser());
|
dataSetService.getModelIdToDataSetIds(new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser());
|
||||||
// 2.detect by segment
|
// 2.detect by segment
|
||||||
List<S2Term> originals = knowledgeService.getTerms(queryText, modelIdToDataSetIds);
|
List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds);
|
||||||
log.info("hanlp parse result: {}", originals);
|
log.info("hanlp parse result: {}", originals);
|
||||||
Set<Long> dataSetIds = queryReq.getDataSetIds();
|
Set<Long> dataSetIds = queryReq.getDataSetIds();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user