(improvement)(chat) Add 'Few-shot Examples' display to the Chat chart. (#589)

This commit is contained in:
lexluo09
2024-01-02 18:07:15 +08:00
committed by GitHub
parent e7f13572d7
commit d72166944c
13 changed files with 118 additions and 59 deletions

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
import java.util.HashMap;
import java.util.Map;
import org.junit.Assert;
@@ -15,38 +16,40 @@ class LLMResponseServiceTest {
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp = new LLMResp();
Map<String, Double> sqlWeight = new HashMap<>();
sqlWeight.put(sql1, 0.2D);
sqlWeight.put(sql2, 0.8D);
llmResp.setSqlWeight(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
llmResp.setSqlRespMap(sqlWeight);
LLMResponseService llmResponseService = new LLMResponseService();
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp2 = new LLMResp();
Map<String, Double> sqlWeight2 = new HashMap<>();
sqlWeight2.put(sql1, 0.2D);
sqlWeight2.put(sql2, 0.8D);
llmResp2.setSqlWeight(sqlWeight2);
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
llmResp2.setSqlRespMap(sqlWeight2);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
Assert.assertEquals(deduplicationSqlResp.size(), 1);
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
LLMResp llmResp3 = new LLMResp();
Map<String, Double> sqlWeight3 = new HashMap<>();
sqlWeight3.put(sql1, 0.2D);
sqlWeight3.put(sql2, 0.8D);
llmResp3.setSqlWeight(sqlWeight3);
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3);
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
llmResp3.setSqlRespMap(sqlWeight3);
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
Assert.assertEquals(deduplicationSqlWeight.size(), 2);
Assert.assertEquals(deduplicationSqlResp.size(), 2);
}
}