add metric and dimension name check (#149)

* (improvement)(semantic) add metric and dimension name check

* (improvement)(chat) opt QueryResponder recalling history similar solved query

---------

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-09-26 20:17:52 +08:00
committed by GitHub
parent 4ad3e1d9cf
commit ff5479f1a2
8 changed files with 176 additions and 67 deletions

View File

@@ -0,0 +1,25 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class SolvedQueryReq {
private Long queryId;
private Integer parseId;
private String queryText;
private Long modelId;
private Integer agentId;
}

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.queryresponder;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
@@ -42,12 +44,17 @@ public class DefaultQueryResponder implements QueryResponder {
}
@Override
public void saveSolvedQuery(String queryText, Long queryId, Integer parseId) {
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
String queryText = solvedQueryReq.getQueryText();
try {
String uniqueId = generateUniqueId(queryId, parseId);
Map<String, String> requestMap = new HashMap<>();
String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId());
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("query", queryText);
requestMap.put("query_id", uniqueId);
Map<String, Object> metaData = new HashMap<>();
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
requestMap.put("metadata", metaData);
doRequest(embeddingConfig.getSolvedQueryAddPath(),
JSONObject.toJSONString(Lists.newArrayList(requestMap)));
} catch (Exception e) {
@@ -56,7 +63,7 @@ public class DefaultQueryResponder implements QueryResponder {
}
@Override
public List<SolvedQueryRecallResp> recallSolvedQuery(String queryText) {
public List<SolvedQueryRecallResp> recallSolvedQuery(String queryText, Integer agentId) {
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
try {
String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results="
@@ -66,7 +73,12 @@ public class DefaultQueryResponder implements QueryResponder {
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
String jsonBody = JSONObject.toJSONString(Lists.newArrayList(queryText));
Map<String, Object> map = new HashMap<>();
map.put("queryTextsList", Lists.newArrayList(queryText));
Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put("agentId", String.valueOf(agentId));
map.put("filterCondition", filterCondition);
String jsonBody = JSONObject.toJSONString(map, SerializerFeature.WriteMapNullValue);
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
ResponseEntity<List<EmbeddingResp>> embeddingResponseEntity =
@@ -80,12 +92,14 @@ public class DefaultQueryResponder implements QueryResponder {
for (EmbeddingResp embeddingResp : embeddingResps) {
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
if (queryText.equalsIgnoreCase(embeddingRetrieval.getQuery())) {
continue;
}
if (querySet.contains(embeddingRetrieval.getQuery())) {
continue;
}
String id = embeddingRetrieval.getId();
SolvedQueryRecallResp solvedQueryRecallResp =
SolvedQueryRecallResp.builder()
SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder()
.queryText(embeddingRetrieval.getQuery())
.queryId(getQueryId(id)).parseId(getParseId(id))
.build();

View File

@@ -1,12 +1,13 @@
package com.tencent.supersonic.chat.queryresponder;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import java.util.List;
public interface QueryResponder {
void saveSolvedQuery(String queryText, Long queryId, Integer parseId);
void saveSolvedQuery(SolvedQueryReq solvedQueryReq);
List<SolvedQueryRecallResp> recallSolvedQuery(String queryText);
List<SolvedQueryRecallResp> recallSolvedQuery(String queryText, Integer agentId);
}

View File

@@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
@@ -21,6 +22,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QuerySelector;
@@ -150,7 +152,7 @@ public class QueryServiceImpl implements QueryService {
.build();
}
List<SolvedQueryRecallResp> solvedQueryRecallResps =
queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText());
queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText(), queryReq.getAgentId());
parseResult.setSimilarSolvedQuery(solvedQueryRecallResps);
return parseResult;
}
@@ -172,6 +174,7 @@ public class QueryServiceImpl implements QueryService {
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(),
queryReq.getUser().getName(), queryReq.getParseId());
ChatQueryDO chatQueryDO = chatService.getLastQuery(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>();
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
@@ -196,6 +199,11 @@ public class QueryServiceImpl implements QueryService {
if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) {
chatCtx.setParseInfo(parseInfo);
chatService.updateContext(chatCtx);
queryResponder.saveSolvedQuery(SolvedQueryReq.builder().parseId(queryReq.getParseId())
.queryId(queryReq.getQueryId())
.agentId(chatQueryDO.getAgentId())
.modelId(parseInfo.getModelId())
.queryText(queryReq.getQueryText()).build());
}
chatCtx.setQueryText(queryReq.getQueryText());
chatCtx.setUser(queryReq.getUser().getName());