mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-01-04 00:12:47 +08:00
(improvement)(chat) Add 'Few-shot Examples' display to the Chat chart. (#589)
This commit is contained in:
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.LLMResponseService;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
@@ -15,38 +16,40 @@ class LLMResponseServiceTest {
|
||||
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
Map<String, Double> sqlWeight = new HashMap<>();
|
||||
sqlWeight.put(sql1, 0.2D);
|
||||
sqlWeight.put(sql2, 0.8D);
|
||||
llmResp.setSqlWeight(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
|
||||
Map<String, LLMSqlResp> sqlWeight = new HashMap<>();
|
||||
sqlWeight.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||
llmResp.setSqlRespMap(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp2 = new LLMResp();
|
||||
Map<String, Double> sqlWeight2 = new HashMap<>();
|
||||
sqlWeight2.put(sql1, 0.2D);
|
||||
sqlWeight2.put(sql2, 0.8D);
|
||||
llmResp2.setSqlWeight(sqlWeight2);
|
||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
|
||||
Map<String, LLMSqlResp> sqlWeight2 = new HashMap<>();
|
||||
sqlWeight2.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight2.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||
llmResp2.setSqlRespMap(sqlWeight2);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp2);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp3 = new LLMResp();
|
||||
Map<String, Double> sqlWeight3 = new HashMap<>();
|
||||
sqlWeight3.put(sql1, 0.2D);
|
||||
sqlWeight3.put(sql2, 0.8D);
|
||||
llmResp3.setSqlWeight(sqlWeight3);
|
||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3);
|
||||
Map<String, LLMSqlResp> sqlWeight3 = new HashMap<>();
|
||||
sqlWeight3.put(sql1, LLMSqlResp.builder().sqlWeight(0.20).build());
|
||||
sqlWeight3.put(sql2, LLMSqlResp.builder().sqlWeight(0.80).build());
|
||||
llmResp3.setSqlRespMap(sqlWeight3);
|
||||
deduplicationSqlResp = llmResponseService.getDeduplicationSqlResp(llmResp3);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 2);
|
||||
Assert.assertEquals(deduplicationSqlResp.size(), 2);
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user