(improvement)(Chat) llmSqlParser is adapted for tag mode, and rule parsing filters based on the dataset query type. (#804)

This commit is contained in:
lexluo09
2024-03-12 17:03:08 +08:00
committed by GitHub
parent c2316c944d
commit bcc0f9caa9
11 changed files with 72 additions and 66 deletions

View File

@@ -20,9 +20,10 @@ import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
@@ -32,9 +33,6 @@ import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
@Component
@Slf4j
@Order(3)
@@ -170,7 +168,6 @@ public class ChatDemoLoader implements CommandLineRunner {
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
@@ -196,7 +193,6 @@ public class ChatDemoLoader implements CommandLineRunner {
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(2L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {

View File

@@ -58,7 +58,7 @@ public class TagTest extends BaseTest {
List<String> list = new ArrayList<>();
list.add("流行");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
"流行", "风格", 6L);
"流行", "风格", 2L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
SchemaElement metric = SchemaElement.builder().name("播放量").build();