mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(headless)Refactor LLMParser impl naming and structure.
This commit is contained in:
@@ -11,7 +11,7 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.PythonLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@@ -13,7 +13,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class KnowledgeService {
|
||||
public class KnowledgeBaseService {
|
||||
|
||||
public void updateSemanticKnowledge(List<DictWord> natures) {
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.mapper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -36,7 +36,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private KnowledgeService knowledgeService;
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
@@ -65,11 +65,11 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
String detectSegment, int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -29,7 +29,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
private static final int SEARCH_SIZE = 3;
|
||||
|
||||
@Autowired
|
||||
private KnowledgeService knowledgeService;
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||
@@ -57,9 +57,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
String detectSegment = text.substring(detectIndex);
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* LLMProxy based on langchain4j Java version.
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
if (Objects.isNull(chatLanguageModel)) {
|
||||
log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, dataSetId);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SqlExample {
|
||||
public class Exemplar {
|
||||
|
||||
private String question;
|
||||
|
||||
@@ -27,29 +27,29 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SqlExamplarLoader {
|
||||
public class ExemplarManager {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
||||
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
|
||||
};
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<SqlExample> getSqlExamples() throws IOException {
|
||||
public List<Exemplar> getExemplars() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
}
|
||||
|
||||
public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
|
||||
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||
for (int i = 0; i < sqlExamples.size(); i++) {
|
||||
SqlExample sqlExample = sqlExamples.get(i);
|
||||
String question = sqlExample.getQuestion();
|
||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class);
|
||||
for (int i = 0; i < exemplars.size(); i++) {
|
||||
Exemplar exemplar = exemplars.get(i);
|
||||
String question = exemplar.getQuestion();
|
||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(String.valueOf(i));
|
||||
embeddingQuery.setQuery(question);
|
||||
@@ -59,7 +59,7 @@ public class SqlExamplarLoader {
|
||||
s2EmbeddingStore.addQuery(collectionName, queries);
|
||||
}
|
||||
|
||||
public List<Map<String, String>> retrieverSqlExamples(String queryText, int maxResults) {
|
||||
public List<Map<String, String>> recallExemplars(String queryText, int maxResults) {
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||
.queryEmbeddings(null).build();
|
||||
@@ -0,0 +1,28 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* LLMProxy based on langchain4j Java version.
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
public LLMResp text2sql(LLMReq llmReq) {
|
||||
|
||||
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(
|
||||
SqlGenType.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGenStrategy.generate(llmReq);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser;
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
|
||||
@@ -12,8 +11,6 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
*/
|
||||
public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, Long dataSetId);
|
||||
LLMResp text2sql(LLMReq llmReq);
|
||||
|
||||
}
|
||||
@@ -45,13 +45,12 @@ public class LLMRequestService {
|
||||
log.info("not enable llm, skip");
|
||||
return true;
|
||||
}
|
||||
if (ComponentFactory.getLLMProxy().isSkip(queryCtx)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (SatisfactionChecker.isSkip(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -72,6 +71,7 @@ public class LLMRequestService {
|
||||
llmReq.setFilterCondition(filterCondition);
|
||||
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setDataSetId(dataSetId);
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
||||
|
||||
@@ -95,13 +95,13 @@ public class LLMRequestService {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
|
||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenType().getName());
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
public LLMResp requestLLM(LLMReq llmReq, Long dataSetId) {
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, dataSetId);
|
||||
public LLMResp invokeLLM(LLMReq llmReq) {
|
||||
return ComponentFactory.getLLMProxy().text2sql(llmReq);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||
|
||||
@@ -43,7 +43,7 @@ public class LLMSqlParser implements SemanticParser {
|
||||
List<LLMReq.ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
|
||||
LLMResp llmResp = requestService.invokeLLM(llmReq);
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -21,21 +20,21 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGeneration extends BaseSqlGeneration {
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
|
||||
List<String> linkingSqlPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
||||
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
||||
List<String> llmResults = new CopyOnWriteArrayList<>();
|
||||
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
||||
@@ -67,6 +66,6 @@ public class OnePassSCSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -19,17 +18,17 @@ import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSqlGeneration extends BaseSqlGeneration {
|
||||
public class OnePassSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||
String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
@@ -52,6 +51,6 @@ public class OnePassSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this);
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -13,9 +13,6 @@ import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* output format
|
||||
*/
|
||||
@Slf4j
|
||||
public class OutputFormat {
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import java.util.Collections;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class SqlPromptGenerator {
|
||||
public class PromptGenerator {
|
||||
|
||||
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction =
|
||||
@@ -1,15 +1,12 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser;
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.OutputFormat;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpEntity;
|
||||
@@ -30,22 +27,12 @@ import java.util.ArrayList;
|
||||
@Component
|
||||
public class PythonLLMProxy implements LLMProxy {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger(PythonLLMProxy.class);
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.warn("llmParserUrl is empty, skip :{}", PythonLLMProxy.class.getName());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
|
||||
public LLMResp text2sql(LLMReq llmReq) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
log.info("requestLLM request, llmReq:{}", llmReq);
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
try {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -10,22 +12,27 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* SqlGenStrategy abstracts generation step so that
|
||||
* different LLM prompting strategies can be implemented.
|
||||
*/
|
||||
@Service
|
||||
public abstract class BaseSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
public abstract class SqlGenStrategy implements InitializingBean {
|
||||
|
||||
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Autowired
|
||||
protected SqlExamplarLoader sqlExamplarLoader;
|
||||
protected ExemplarManager exemplarManager;
|
||||
|
||||
@Autowired
|
||||
protected OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
protected SqlPromptGenerator sqlPromptGenerator;
|
||||
protected PromptGenerator promptGenerator;
|
||||
|
||||
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
|
||||
return S2ChatModelProvider.provide(llmConfig);
|
||||
}
|
||||
|
||||
abstract LLMResp generate(LLMReq llmReq);
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class SqlGenStrategyFactory {
|
||||
|
||||
private static Map<LLMReq.SqlGenType, SqlGenStrategy> sqlGenStrategyMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static SqlGenStrategy get(LLMReq.SqlGenType strategyType) {
|
||||
return sqlGenStrategyMap.get(strategyType);
|
||||
}
|
||||
|
||||
public static void addSqlGenerationForFactory(LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) {
|
||||
sqlGenStrategyMap.put(strategy, sqlGenStrategy);
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* Sql Generation interface, generating SQL using a large model.
|
||||
*/
|
||||
public interface SqlGeneration {
|
||||
|
||||
/***
|
||||
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
||||
* @param llmReq
|
||||
* @param dataSetId
|
||||
* @return
|
||||
*/
|
||||
LLMResp generation(LLMReq llmReq, Long dataSetId);
|
||||
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class SqlGenerationFactory {
|
||||
|
||||
private static Map<SqlGenerationMode, SqlGeneration> sqlGenerationMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static SqlGeneration get(SqlGenerationMode strategyType) {
|
||||
return sqlGenerationMap.get(strategyType);
|
||||
}
|
||||
|
||||
public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) {
|
||||
sqlGenerationMap.put(strategy, sqlGeneration);
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -18,20 +18,20 @@ import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Service
|
||||
public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
|
||||
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
//2.generator linking prompt,and parallel generate response.
|
||||
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||
List<String> linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
@@ -47,7 +47,7 @@ public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
|
||||
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(sortedList);
|
||||
//3.generator sql prompt,and parallel generate response.
|
||||
List<String> sqlPromptPool = sqlPromptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
|
||||
List<String> sqlPromptPool = promptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
|
||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||
@@ -69,6 +69,6 @@ public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -18,15 +18,15 @@ import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSqlGeneration extends BaseSqlGeneration {
|
||||
public class TwoPassSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||
String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
|
||||
@@ -34,7 +34,7 @@ public class TwoPassSqlGeneration extends BaseSqlGeneration {
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
keyPipelineLog.info("step one model response:{}", response.content().text());
|
||||
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
|
||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
|
||||
String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
|
||||
|
||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
|
||||
@@ -53,6 +53,6 @@ public class TwoPassSqlGeneration extends BaseSqlGeneration {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT, this);
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,8 @@ public class LLMReq {
|
||||
|
||||
private String dataSetName;
|
||||
|
||||
private Long dataSetId;
|
||||
|
||||
private List<String> fieldNameList;
|
||||
|
||||
}
|
||||
@@ -51,7 +53,7 @@ public class LLMReq {
|
||||
private String tableName;
|
||||
}
|
||||
|
||||
public enum SqlGenerationMode {
|
||||
public enum SqlGenType {
|
||||
|
||||
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
|
||||
|
||||
@@ -64,7 +66,7 @@ public class LLMReq {
|
||||
|
||||
private String name;
|
||||
|
||||
SqlGenerationMode(String name) {
|
||||
SqlGenType(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@@ -73,10 +75,10 @@ public class LLMReq {
|
||||
return name;
|
||||
}
|
||||
|
||||
public static SqlGenerationMode getMode(String name) {
|
||||
for (SqlGenerationMode sqlGenerationMode : SqlGenerationMode.values()) {
|
||||
if (sqlGenerationMode.name.equals(name)) {
|
||||
return sqlGenerationMode;
|
||||
public static SqlGenType getMode(String name) {
|
||||
for (SqlGenType sqlGenType : SqlGenType.values()) {
|
||||
if (sqlGenType.name.equals(name)) {
|
||||
return sqlGenType;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.headless.core.config;
|
||||
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -65,7 +65,7 @@ public class OptimizationConfig {
|
||||
private boolean useLinkingValueSwitch;
|
||||
|
||||
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
|
||||
private SqlGenerationMode sqlGenerationMode;
|
||||
private LLMReq.SqlGenType sqlGenType;
|
||||
|
||||
@Value("${s2SQL.use.switch:true}")
|
||||
private boolean useS2SqlSwitch;
|
||||
@@ -157,8 +157,8 @@ public class OptimizationConfig {
|
||||
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||
}
|
||||
|
||||
public SqlGenerationMode getSqlGenerationMode() {
|
||||
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
|
||||
public LLMReq.SqlGenType getSqlGenType() {
|
||||
return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType);
|
||||
}
|
||||
|
||||
public Integer getParseShowCount() {
|
||||
@@ -177,8 +177,8 @@ public class OptimizationConfig {
|
||||
return targetType.cast(Integer.parseInt(value));
|
||||
} else if (targetType == Boolean.class) {
|
||||
return targetType.cast(Boolean.parseBoolean(value));
|
||||
} else if (targetType == SqlGenerationMode.class) {
|
||||
return targetType.cast(SqlGenerationMode.valueOf(value));
|
||||
} else if (targetType == LLMReq.SqlGenType.class) {
|
||||
return targetType.cast(LLMReq.SqlGenType.valueOf(value));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("convertValue", e);
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.core.utils;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.core.cache.QueryCache;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.LLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.LLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.DataSetResolver;
|
||||
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
|
||||
import com.tencent.supersonic.headless.core.parser.SqlParser;
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.server.listener;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DictWord;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.server.service.impl.WordService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -21,7 +21,7 @@ import java.util.concurrent.CompletableFuture;
|
||||
public class ApplicationStartedListener implements CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private KnowledgeService knowledgeService;
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
@Autowired
|
||||
private WordService wordService;
|
||||
|
||||
@@ -37,7 +37,7 @@ public class ApplicationStartedListener implements CommandLineRunner {
|
||||
|
||||
List<DictWord> dictWords = wordService.getAllDictWords();
|
||||
wordService.setPreDictWords(dictWords);
|
||||
knowledgeService.reloadAllData(dictWords);
|
||||
knowledgeBaseService.reloadAllData(dictWords);
|
||||
|
||||
log.debug("ApplicationStartedInit end");
|
||||
isOk = true;
|
||||
@@ -72,7 +72,7 @@ public class ApplicationStartedListener implements CommandLineRunner {
|
||||
}
|
||||
log.info("dictWords has changed");
|
||||
wordService.setPreDictWords(dictWords);
|
||||
knowledgeService.updateOnlineKnowledge(wordService.getAllDictWords());
|
||||
knowledgeBaseService.updateOnlineKnowledge(wordService.getAllDictWords());
|
||||
} catch (Exception e) {
|
||||
log.error("reloadKnowledge error", e);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.listener;
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.server.schedule.EmbeddingTask;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.headless.server.listener;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlExamplarLoader;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.SqlExample;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.ExemplarManager;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.Exemplar;
|
||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -19,7 +19,7 @@ import java.util.List;
|
||||
public class SqlEmbeddingListener implements CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
private ExemplarManager exemplarManager;
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@@ -31,9 +31,9 @@ public class SqlEmbeddingListener implements CommandLineRunner {
|
||||
public void initSqlExamples() {
|
||||
try {
|
||||
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
||||
List<SqlExample> sqlExamples = sqlExamplarLoader.getSqlExamples();
|
||||
List<Exemplar> exemplars = exemplarManager.getExemplars();
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
sqlExamplarLoader.addEmbeddingStore(sqlExamples, collectionName);
|
||||
exemplarManager.addExemplars(exemplars, collectionName);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("initSqlExamples error", e);
|
||||
|
||||
@@ -42,7 +42,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
|
||||
import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
@@ -96,7 +96,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
@Autowired
|
||||
private ChatContextService chatContextService;
|
||||
@Autowired
|
||||
private KnowledgeService knowledgeService;
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
@Autowired
|
||||
private QueryService queryService;
|
||||
@Autowired
|
||||
@@ -557,7 +557,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
|
||||
modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds));
|
||||
//search from prefixSearch
|
||||
List<HanlpMapResult> hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(),
|
||||
List<HanlpMapResult> hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(),
|
||||
2000, modelIdToDataSetIds, dataSetIds);
|
||||
HanlpHelper.transLetterOriginal(hanlpMapResultList);
|
||||
return hanlpMapResultList.stream()
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DictItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp;
|
||||
import com.tencent.supersonic.headless.core.file.FileHandler;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.DictTaskDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.repository.DictRepository;
|
||||
@@ -46,7 +46,7 @@ public class DictTaskServiceImpl implements DictTaskService {
|
||||
DictUtils dictConverter,
|
||||
DictUtils dictUtils,
|
||||
FileHandler fileHandler,
|
||||
KnowledgeService knowledgeService) {
|
||||
KnowledgeBaseService knowledgeBaseService) {
|
||||
this.dictRepository = dictRepository;
|
||||
this.dictConverter = dictConverter;
|
||||
this.dictUtils = dictUtils;
|
||||
|
||||
@@ -18,7 +18,7 @@ import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DictWord;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
import com.tencent.supersonic.headless.server.service.ChatContextService;
|
||||
@@ -58,7 +58,7 @@ public class SearchServiceImpl implements SearchService {
|
||||
@Autowired
|
||||
private ChatContextService chatContextService;
|
||||
@Autowired
|
||||
private KnowledgeService knowledgeService;
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
@Autowired
|
||||
private DataSetService dataSetService;
|
||||
|
||||
@@ -73,7 +73,7 @@ public class SearchServiceImpl implements SearchService {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds =
|
||||
dataSetService.getModelIdToDataSetIds(new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser());
|
||||
// 2.detect by segment
|
||||
List<S2Term> originals = knowledgeService.getTerms(queryText, modelIdToDataSetIds);
|
||||
List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds);
|
||||
log.info("hanlp parse result: {}", originals);
|
||||
Set<Long> dataSetIds = queryReq.getDataSetIds();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user