From 3915d90eecade2a1ff63767f55737dc537ece1d6 Mon Sep 17 00:00:00 2001 From: yudong Date: Mon, 26 Aug 2024 10:20:52 +0800 Subject: [PATCH] =?UTF-8?q?(improvement)(chat)=E5=A4=9A=E8=BD=AE=E5=AF=B9?= =?UTF-8?q?=E8=AF=9D=E5=A2=9E=E5=8A=A0rewrittenQuery=E7=9A=84MapInfo=20(#1?= =?UTF-8?q?599)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- benchmark/benchmark.py | 5 +++-- benchmark/requirements.txt | 2 +- .../tencent/supersonic/chat/server/parser/NL2SQLParser.java | 4 +++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 81dfeea43..68c53c5ac 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -38,9 +38,10 @@ class BatchTest: response = requests.post(url, headers=headers, data=json.dumps(data)) return response.json() - def execute(self, query_text, queryId): + def execute(self, agentId, query_text, queryId): url = self.base_url + 'execute' data = { + 'agentId': agentId, 'queryText': query_text, 'parseId': 1, 'chatId': self.chatId, @@ -75,7 +76,7 @@ def benchmark(url:str, agentId:str, chatId:str, filePath:str, userName:str): # 捕获异常,防止程序中断 try: parse_resp = batch_test.parse(question) - batch_test.execute(question, parse_resp['data']['queryId']) + batch_test.execute(agentId, question, parse_resp['data']['queryId']) except Exception as e: print('error:', e) traceback.print_exc() diff --git a/benchmark/requirements.txt b/benchmark/requirements.txt index d6f5648d8..12e8040fd 100644 --- a/benchmark/requirements.txt +++ b/benchmark/requirements.txt @@ -1,4 +1,4 @@ -pandas==2.0.3 +pandas==2.2.2 PyJWT==2.8.0 requests==2.28.2 diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 5330a957d..0f3ad22d2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -203,8 +203,10 @@ public class NL2SQLParser implements ChatQueryParser { Response response = chatLanguageModel.generate(prompt.toUserMessage()); String rewrittenQuery = response.content().text(); keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery); - parseContext.setQueryText(rewrittenQuery); + QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); + MapResp rewrittenQueryMapResult = chatLayerService.performMapping(rewrittenQueryNLReq); + parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo()); log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); }