mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(common)Rename SqlExemplar to Text2SQLExemplar.
This commit is contained in:
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
|||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
@@ -41,9 +41,9 @@ public class SqlExecutor implements ChatQueryExecutor {
|
|||||||
|
|
||||||
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||||
SqlExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
Text2SQLExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
||||||
executeContext.getParseInfo().getProperties()
|
executeContext.getParseInfo().getProperties()
|
||||||
.get(SqlExemplar.PROPERTY_KEY)), SqlExemplar.class);
|
.get(Text2SQLExemplar.PROPERTY_KEY)), Text2SQLExemplar.class);
|
||||||
|
|
||||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||||
memoryService.createMemory(ChatMemoryDO.builder()
|
memoryService.createMemory(ChatMemoryDO.builder()
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ public class MemoryReviewTask {
|
|||||||
+ "please take a review and give your opinion.\n"
|
+ "please take a review and give your opinion.\n"
|
||||||
+ "#Rules: "
|
+ "#Rules: "
|
||||||
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||||
+ "2.ALWAYS recognize `数据日期` as the date field.\n"
|
+ "2.ALWAYS recognize `数据日期` as the date field."
|
||||||
|
+ "3.IGNORE `数据日期` if not expressed in the `Question`."
|
||||||
+ "#Question: %s\n"
|
+ "#Question: %s\n"
|
||||||
+ "#Schema: %s\n"
|
+ "#Schema: %s\n"
|
||||||
+ "#SideInfo: %s\n"
|
+ "#SideInfo: %s\n"
|
||||||
@@ -57,12 +58,12 @@ public class MemoryReviewTask {
|
|||||||
m.getSideInfo(), m.getS2sql());
|
m.getSideInfo(), m.getS2sql());
|
||||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
|
|
||||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
||||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||||
chatAgent.getModelConfig());
|
chatAgent.getModelConfig());
|
||||||
if (Objects.nonNull(chatLanguageModel)) {
|
if (Objects.nonNull(chatLanguageModel)) {
|
||||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||||
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
|
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
||||||
|
|
||||||
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||||
if (matcher.find()) {
|
if (matcher.find()) {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
|||||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
@@ -207,7 +207,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
||||||
String errMsg, List<SqlExemplar> similarExemplars,
|
String errMsg, List<Text2SQLExemplar> similarExemplars,
|
||||||
List<String> agentExamples) {
|
List<String> agentExamples) {
|
||||||
Map<String, Object> variables = new HashMap<>();
|
Map<String, Object> variables = new HashMap<>();
|
||||||
variables.put("user_question", userQuestion);
|
variables.put("user_question", userQuestion);
|
||||||
@@ -276,7 +276,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||||
queryNLReq.getQueryText(), 5);
|
queryNLReq.getQueryText(), 5);
|
||||||
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
@@ -43,7 +43,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
|||||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
List<Text2SQLExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||||
return exemplars.stream().map(sqlExemplar ->
|
return exemplars.stream().map(sqlExemplar ->
|
||||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -100,7 +100,7 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
public void enableMemory(ChatMemoryDO memory) {
|
public void enableMemory(ChatMemoryDO memory) {
|
||||||
memory.setStatus(MemoryStatus.ENABLED);
|
memory.setStatus(MemoryStatus.ENABLED);
|
||||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
SqlExemplar.builder()
|
Text2SQLExemplar.builder()
|
||||||
.question(memory.getQuestion())
|
.question(memory.getQuestion())
|
||||||
.sideInfo(memory.getSideInfo())
|
.sideInfo(memory.getSideInfo())
|
||||||
.dbSchema(memory.getDbSchema())
|
.dbSchema(memory.getDbSchema())
|
||||||
@@ -112,7 +112,7 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
public void disableMemory(ChatMemoryDO memory) {
|
public void disableMemory(ChatMemoryDO memory) {
|
||||||
memory.setStatus(MemoryStatus.DISABLED);
|
memory.setStatus(MemoryStatus.DISABLED);
|
||||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
SqlExemplar.builder()
|
Text2SQLExemplar.builder()
|
||||||
.question(memory.getQuestion())
|
.question(memory.getQuestion())
|
||||||
.sideInfo(memory.getSideInfo())
|
.sideInfo(memory.getSideInfo())
|
||||||
.dbSchema(memory.getDbSchema())
|
.dbSchema(memory.getDbSchema())
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import lombok.NoArgsConstructor;
|
|||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class SqlExemplar {
|
public class Text2SQLExemplar {
|
||||||
|
|
||||||
public static final String PROPERTY_KEY = "sql_exemplar";
|
public static final String PROPERTY_KEY = "sql_exemplar";
|
||||||
|
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
package com.tencent.supersonic.common.service;
|
package com.tencent.supersonic.common.service;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface ExemplarService {
|
public interface ExemplarService {
|
||||||
void storeExemplar(String collection, SqlExemplar exemplar);
|
void storeExemplar(String collection, Text2SQLExemplar exemplar);
|
||||||
|
|
||||||
void removeExemplar(String collection, SqlExemplar exemplar);
|
void removeExemplar(String collection, Text2SQLExemplar exemplar);
|
||||||
|
|
||||||
List<SqlExemplar> recallExemplars(String collection, String query, int num);
|
List<Text2SQLExemplar> recallExemplars(String collection, String query, int num);
|
||||||
|
|
||||||
List<SqlExemplar> recallExemplars(String query, int num);
|
List<Text2SQLExemplar> recallExemplars(String query, int num);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ import com.fasterxml.jackson.core.type.TypeReference;
|
|||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
|
||||||
import dev.langchain4j.data.document.Metadata;
|
import dev.langchain4j.data.document.Metadata;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
@@ -31,7 +31,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
|
|
||||||
private static final String SYS_EXEMPLAR_FILE = "s2-exemplar.json";
|
private static final String SYS_EXEMPLAR_FILE = "s2-exemplar.json";
|
||||||
|
|
||||||
private TypeReference<List<SqlExemplar>> valueTypeRef = new TypeReference<List<SqlExemplar>>() {
|
private TypeReference<List<Text2SQLExemplar>> valueTypeRef = new TypeReference<List<Text2SQLExemplar>>() {
|
||||||
};
|
};
|
||||||
|
|
||||||
private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper();
|
private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper();
|
||||||
@@ -42,7 +42,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingService embeddingService;
|
private EmbeddingService embeddingService;
|
||||||
|
|
||||||
public void storeExemplar(String collection, SqlExemplar exemplar) {
|
public void storeExemplar(String collection, Text2SQLExemplar exemplar) {
|
||||||
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
||||||
String.class, Object.class));
|
String.class, Object.class));
|
||||||
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||||
@@ -51,7 +51,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
embeddingService.addQuery(collection, Lists.newArrayList(segment));
|
embeddingService.addQuery(collection, Lists.newArrayList(segment));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void removeExemplar(String collection, SqlExemplar exemplar) {
|
public void removeExemplar(String collection, Text2SQLExemplar exemplar) {
|
||||||
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
||||||
String.class, Object.class));
|
String.class, Object.class));
|
||||||
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||||
@@ -59,20 +59,20 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
embeddingService.deleteQuery(collection, Lists.newArrayList(segment));
|
embeddingService.deleteQuery(collection, Lists.newArrayList(segment));
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SqlExemplar> recallExemplars(String query, int num) {
|
public List<Text2SQLExemplar> recallExemplars(String query, int num) {
|
||||||
String collection = embeddingConfig.getText2sqlCollectionName();
|
String collection = embeddingConfig.getText2sqlCollectionName();
|
||||||
return recallExemplars(collection, query, num);
|
return recallExemplars(collection, query, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SqlExemplar> recallExemplars(String collection, String query, int num) {
|
public List<Text2SQLExemplar> recallExemplars(String collection, String query, int num) {
|
||||||
List<SqlExemplar> exemplars = Lists.newArrayList();
|
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
.queryTextsList(Lists.newArrayList(query))
|
.queryTextsList(Lists.newArrayList(query))
|
||||||
.build();
|
.build();
|
||||||
List<RetrieveQueryResult> results = embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
List<RetrieveQueryResult> results = embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||||
results.stream().forEach(ret -> {
|
results.stream().forEach(ret -> {
|
||||||
ret.getRetrieval().stream().forEach(r -> {
|
ret.getRetrieval().stream().forEach(r -> {
|
||||||
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), SqlExemplar.class));
|
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
private void loadSysExemplars() throws IOException {
|
private void loadSysExemplars() throws IOException {
|
||||||
ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE);
|
ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE);
|
||||||
InputStream inputStream = resource.getInputStream();
|
InputStream inputStream = resource.getInputStream();
|
||||||
List<SqlExemplar> exemplars = objectMapper.readValue(inputStream, valueTypeRef);
|
List<Text2SQLExemplar> exemplars = objectMapper.readValue(inputStream, valueTypeRef);
|
||||||
String collection = embeddingConfig.getText2sqlCollectionName();
|
String collection = embeddingConfig.getText2sqlCollectionName();
|
||||||
exemplars.stream().forEach(e -> storeExemplar(collection, e));
|
exemplars.stream().forEach(e -> storeExemplar(collection, e));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import com.google.common.collect.Sets;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
@@ -29,6 +29,6 @@ public class QueryNLReq {
|
|||||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||||
private ChatModelConfig modelConfig;
|
private ChatModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
|
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
|
||||||
private SemanticParseInfo contextParseInfo;
|
private SemanticParseInfo contextParseInfo;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||||
@@ -52,7 +52,7 @@ public class ChatQueryContext {
|
|||||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||||
private ChatModelConfig modelConfig;
|
private ChatModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
private List<SqlExemplar> dynamicExemplars;
|
private List<Text2SQLExemplar> dynamicExemplars;
|
||||||
private SemanticParseInfo contextParseInfo;
|
private SemanticParseInfo contextParseInfo;
|
||||||
|
|
||||||
public List<SemanticQuery> getCandidateQueries() {
|
public List<SemanticQuery> getCandidateQueries() {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||||
@@ -36,13 +36,13 @@ public class LLMResponseService {
|
|||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
properties.put(Constants.CONTEXT, parseResult);
|
properties.put(Constants.CONTEXT, parseResult);
|
||||||
properties.put("type", "internal");
|
properties.put("type", "internal");
|
||||||
SqlExemplar exemplar = SqlExemplar.builder()
|
Text2SQLExemplar exemplar = Text2SQLExemplar.builder()
|
||||||
.question(queryCtx.getQueryText())
|
.question(queryCtx.getQueryText())
|
||||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||||
.sql(parseResult.getLlmResp().getSqlOutput())
|
.sql(parseResult.getLlmResp().getSqlOutput())
|
||||||
.build();
|
.build();
|
||||||
properties.put(SqlExemplar.PROPERTY_KEY, exemplar);
|
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
|
||||||
parseInfo.setProperties(properties);
|
parseInfo.setProperties(properties);
|
||||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
@@ -45,11 +45,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
llmResp.setQuery(llmReq.getQueryText());
|
llmResp.setQuery(llmReq.getQueryText());
|
||||||
//1.recall exemplars
|
//1.recall exemplars
|
||||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||||
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||||
|
|
||||||
//2.generate sql generation prompt for each self-consistency inference
|
//2.generate sql generation prompt for each self-consistency inference
|
||||||
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
|
Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
|
||||||
for (List<SqlExemplar> exemplars : exemplarsList) {
|
for (List<Text2SQLExemplar> exemplars : exemplarsList) {
|
||||||
llmReq.setDynamicExemplars(exemplars);
|
llmReq.setDynamicExemplars(exemplars);
|
||||||
Prompt prompt = generatePrompt(llmReq, llmResp);
|
Prompt prompt = generatePrompt(llmReq, llmResp);
|
||||||
prompt2Exemplar.put(prompt, exemplars);
|
prompt2Exemplar.put(prompt, exemplars);
|
||||||
@@ -61,9 +61,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
String result = response.content().text();
|
String sqlOutput = StringUtils.normalizeSpace(response.content().text());
|
||||||
output2Prompt.put(result, prompt);
|
output2Prompt.put(sqlOutput, prompt);
|
||||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutput);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
|
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
|
||||||
Lists.newArrayList(output2Prompt.keySet()));
|
Lists.newArrayList(output2Prompt.keySet()));
|
||||||
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||||
List<SqlExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
List<Text2SQLExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||||
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
|
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
|
||||||
|
|
||||||
return llmResp;
|
return llmResp;
|
||||||
@@ -79,7 +79,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
|
|
||||||
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
|
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
|
||||||
StringBuilder exemplars = new StringBuilder();
|
StringBuilder exemplars = new StringBuilder();
|
||||||
for (SqlExemplar exemplar : llmReq.getDynamicExemplars()) {
|
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||||
String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n",
|
String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n",
|
||||||
exemplar.getQuestion(), exemplar.getDbSchema(),
|
exemplar.getQuestion(), exemplar.getDbSchema(),
|
||||||
exemplar.getSideInfo(), exemplar.getSql());
|
exemplar.getSideInfo(), exemplar.getSql());
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
@@ -29,12 +29,12 @@ public class PromptHelper {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ExemplarService exemplarService;
|
private ExemplarService exemplarService;
|
||||||
|
|
||||||
public List<List<SqlExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||||
|
|
||||||
List<SqlExemplar> exemplars = Lists.newArrayList();
|
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
||||||
llmReq.getDynamicExemplars().stream().forEach(e -> {
|
llmReq.getDynamicExemplars().stream().forEach(e -> {
|
||||||
exemplars.add(e);
|
exemplars.add(e);
|
||||||
});
|
});
|
||||||
@@ -44,10 +44,10 @@ public class PromptHelper {
|
|||||||
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
|
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<List<SqlExemplar>> results = new ArrayList<>();
|
List<List<Text2SQLExemplar>> results = new ArrayList<>();
|
||||||
// use random collection of exemplars for each self-consistency inference
|
// use random collection of exemplars for each self-consistency inference
|
||||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||||
List<SqlExemplar> shuffledList = new ArrayList<>(exemplars);
|
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||||
Collections.shuffle(shuffledList);
|
Collections.shuffle(shuffledList);
|
||||||
results.add(shuffledList.subList(0, fewShotNumber));
|
results.add(shuffledList.subList(0, fewShotNumber));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
@@ -54,7 +54,7 @@ public class ResponseHelper {
|
|||||||
return Pair.of(inputMax, votePercentage);
|
return Pair.of(inputMax, votePercentage);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
|
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Text2SQLExemplar> sqlExamples,
|
||||||
Map<String, Double> sqlMap) {
|
Map<String, Double> sqlMap) {
|
||||||
if (sqlMap == null) {
|
if (sqlMap == null) {
|
||||||
return new HashMap<>();
|
return new HashMap<>();
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ public class LLMReq {
|
|||||||
private ChatModelConfig modelConfig;
|
private ChatModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
|
|
||||||
private List<SqlExemplar> dynamicExemplars;
|
private List<Text2SQLExemplar> dynamicExemplars;
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -16,6 +16,6 @@ public class LLMSqlResp {
|
|||||||
|
|
||||||
private double sqlWeight;
|
private double sqlWeight;
|
||||||
|
|
||||||
private List<SqlExemplar> fewShots;
|
private List<Text2SQLExemplar> fewShots;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user