mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40ea6a9396 | ||
|
|
78d724ea83 | ||
|
|
eadbdc4e30 | ||
|
|
b8831317e9 |
@@ -1,4 +1,4 @@
|
|||||||
name: supersonic ubuntu CI
|
name: supersonic CI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
35
.github/workflows/mac-ci.yml
vendored
35
.github/workflows/mac-ci.yml
vendored
@@ -1,35 +0,0 @@
|
|||||||
name: supersonic mac CI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: macos-latest # Specify a macOS runner
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
|
|
||||||
- name: Set up JDK 8
|
|
||||||
uses: actions/setup-java@v2
|
|
||||||
with:
|
|
||||||
java-version: '8'
|
|
||||||
distribution: 'adopt'
|
|
||||||
|
|
||||||
- name: Cache Maven packages
|
|
||||||
uses: actions/cache@v2
|
|
||||||
with:
|
|
||||||
path: ~/Library/Caches/Maven # macOS Maven cache path
|
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
|
||||||
restore-keys: ${{ runner.os }}-m2
|
|
||||||
|
|
||||||
- name: Build with Maven
|
|
||||||
run: mvn -B package --file pom.xml
|
|
||||||
|
|
||||||
- name: Test with Maven
|
|
||||||
run: mvn test
|
|
||||||
35
.github/workflows/windows-ci.yml
vendored
35
.github/workflows/windows-ci.yml
vendored
@@ -1,35 +0,0 @@
|
|||||||
name: supersonic windows CI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: windows-latest # Specify a Windows runner
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
|
|
||||||
- name: Set up JDK 8
|
|
||||||
uses: actions/setup-java@v2
|
|
||||||
with:
|
|
||||||
java-version: '8'
|
|
||||||
distribution: 'adopt'
|
|
||||||
|
|
||||||
- name: Cache Maven packages
|
|
||||||
uses: actions/cache@v2
|
|
||||||
with:
|
|
||||||
path: ~\.m2 # Windows uses a backslash for paths
|
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
|
||||||
restore-keys: ${{ runner.os }}-m2
|
|
||||||
|
|
||||||
- name: Build with Maven
|
|
||||||
run: mvn -B package --file pom.xml
|
|
||||||
|
|
||||||
- name: Test with Maven
|
|
||||||
run: mvn test
|
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -8,7 +8,6 @@ log/
|
|||||||
*.bin
|
*.bin
|
||||||
*.log
|
*.log
|
||||||
*.tar.gz
|
*.tar.gz
|
||||||
*.zip
|
|
||||||
*.lib
|
*.lib
|
||||||
assembly/runtime/*
|
assembly/runtime/*
|
||||||
**/dist/
|
**/dist/
|
||||||
@@ -18,5 +17,4 @@ assembly/runtime/*
|
|||||||
**/.flattened-pom.xml
|
**/.flattened-pom.xml
|
||||||
chm_db/
|
chm_db/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
/dict
|
/dict
|
||||||
assembly/build/*-SNAPSHOT
|
|
||||||
23
CHANGELOG.md
23
CHANGELOG.md
@@ -4,29 +4,6 @@
|
|||||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||||
compatibility issues with previous versions.
|
compatibility issues with previous versions.
|
||||||
|
|
||||||
## SuperSonic [0.9.2] - 2024-06-01
|
|
||||||
|
|
||||||
### Added
|
|
||||||
- support multiple rounds of dialogue
|
|
||||||
- add term configuration and identification to help LLM learn private domain knowledge
|
|
||||||
- support configuring LLM parameters in the agent
|
|
||||||
- metric market supports searching in natural language
|
|
||||||
|
|
||||||
### Updated
|
|
||||||
- introducing WorkFlow, Mapper, Parser, and Corrector support jump execution
|
|
||||||
- Introducing the concept of Model-Set to simplify Domain management
|
|
||||||
- overall optimization and upgrade of system pages
|
|
||||||
- optimize startup script
|
|
||||||
|
|
||||||
## SuperSonic [0.9.0] - 2024-04-03
|
|
||||||
|
|
||||||
### Added
|
|
||||||
- add tag abstraction and enhance tag marketplace management.
|
|
||||||
- headless-server provides Chat API interface.
|
|
||||||
|
|
||||||
### Updated
|
|
||||||
- migrate chat-core core component to headless-core.
|
|
||||||
|
|
||||||
## SuperSonic [0.8.6] - 2024-02-23
|
## SuperSonic [0.8.6] - 2024-02-23
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
42
README.md
42
README.md
@@ -2,39 +2,31 @@
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
# SuperSonic
|
# SuperSonic (超音数)
|
||||||
|
|
||||||
SuperSonic is the next-generation BI platform that integrates **Chat BI** (powered by LLM) and **Headless BI** (powered by semantic layer) paradigms. This integration ensures that Chat BI has access to the same curated and governed semantic data models as traditional BI. Furthermore, the implementation of both paradigms benefits from the integration:
|
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
||||||
|
|
||||||
- Chat BI's Text2SQL gets augmented with context-retrieval from semantic models.
|
|
||||||
- Headless BI's query interface gets extended with natural language API.
|
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_ideas.png" height="75%" width="75%" align="center"/>
|
|
||||||
|
|
||||||
SuperSonic provides a **Chat BI interface** that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of metric/dimension/tag, along with their meaning and relationships) through a **Headless BI interface**. Meanwhile, SuperSonic is designed to be extensible and composable, allowing custom implementations to be added and configured with Java SPI.
|
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||||
|
|
||||||
## Motivation
|
## Motivation
|
||||||
|
|
||||||
The emergence of Large Language Models (LLMs) like ChatGPT is reshaping the way information is retrieved, leading to a new paradigm in the field of data analytics known as Chat BI. To implement Chat BI, both academia and industry are primarily focused on harnessing the power of LLMs to convert natural language into SQL, commonly referred to as Text2SQL or NL2SQL. While some approaches show promising results, their **reliability** falls short for large-scale real-world applications.
|
The emergence of Large Language Models (LLMs) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLMs to convert natural language into SQL (so-called Text2SQL or NL2SQL). While some approaches exhibit promising results, their **reliability** and **efficiency** are insufficient for real-world applications.
|
||||||
|
|
||||||
Meanwhile, another emerging paradigm called Headless BI, which focuses on constructing unified semantic data models, has garnered significant attention. Headless BI is implemented through a universal semantic layer that exposes consistent data semantics via an open API.
|
From our perspective, the key to filling the real-world gap lies in three aspects:
|
||||||
|
1. Integrate ChatBI with HeadlessBI, which encapsulates the underlying data context (joins, keys, formulas, etc.), to **reduce complexity**.
|
||||||
|
2. Augment the LLM with schema mappers (as a kind of preprocessor) and semantic correctors (as a kind of postprocessor) to **mitigate hallucination**.
|
||||||
|
3. Utilize rule-based schema parsers when necessary to **improve efficiency** (in terms of latency and cost).
|
||||||
|
|
||||||
From our perspective, the integration of Chat BI and Headless BI has the potential to enhance the Text2SQL generation in two dimensions:
|
With these ideas in mind, we developed SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development of ChatBI, we decided to open source SuperSonic as an extensible framework.
|
||||||
|
|
||||||
1. Incorporate data semantics (such as business terms, column values, etc.) into the prompt, enabling LLM to better understand the semantics and **reduce hallucination**.
|
|
||||||
2. Offload the generation of advanced SQL syntax (such as join, formula, etc.) from LLM to the semantic layer to **reduce complexity**.
|
|
||||||
|
|
||||||
With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development we decide to open source SuperSonic as an extensible framework.
|
|
||||||
|
|
||||||
## Out-of-the-box Features
|
## Out-of-the-box Features
|
||||||
|
|
||||||
- Built-in Chat BI interface for *business users* to enter natural language queries
|
- Built-in ChatBI interface for *business users* to enter natural language queries
|
||||||
- Built-in Headless BI interface for *analytics engineers* to build semantic data models
|
- Built-in HeadlessBI interface for *analytics engineers* to build semantic models
|
||||||
- Built-in rule-based semantic parser to improve efficiency in certain scenarios (e.g. demonstration, integration testing)
|
- Built-in GUI for *system administrators* to manage chat agents and third-party plugins
|
||||||
- Built-in support for input auto-completion, multi-turn conversation as well as post-query recommendation
|
- Support input auto-completion as well as query recommendation
|
||||||
- Built-in support for three-level data access control: dataset-level, column-level and row-level
|
- Support multi-turn conversation and history context management
|
||||||
|
- Support four-level permission control: domain-level, model-level, column-level and row-level
|
||||||
|
|
||||||
## Extensible Components
|
## Extensible Components
|
||||||
|
|
||||||
@@ -46,11 +38,11 @@ The high-level architecture and main process flow is as follows:
|
|||||||
|
|
||||||
- **Schema Mapper:** identifies references to schema elements (metrics/dimensions/entities/values) in user queries. It matches the query text against the knowledge base.
|
- **Schema Mapper:** identifies references to schema elements (metrics/dimensions/entities/values) in user queries. It matches the query text against the knowledge base.
|
||||||
|
|
||||||
- **Semantic Parser:** understands user queries and generates a semantic query statement. It consists of a combination of rule-based and model-based parsers, each of which deals with specific scenarios.
|
- **Semantic Parser:** understands user queries and extracts semantic information. It consists of a combination of rule-based and model-based parsers, each of which deals with specific scenarios.
|
||||||
|
|
||||||
- **Semantic Corrector:** checks validity of semantic query statement and performs correction and optimization if needed.
|
- **Semantic Corrector:** checks validity of extracted semantic information and performs correction and optimization if needed.
|
||||||
|
|
||||||
- **Semantic Translator:** converts the semantic query statement into a SQL statement that can be executed against physical data models.
|
- **Semantic Interpreter:** performs execution according to extracted semantic information. It generates SQL statements and executes them against physical data models.
|
||||||
|
|
||||||
- **Chat Plugin:** extends functionality with third-party tools. The LLM is going to select the most suitable one, given all configured plugins with function description and sample questions.
|
- **Chat Plugin:** extends functionality with third-party tools. The LLM is going to select the most suitable one, given all configured plugins with function description and sample questions.
|
||||||
|
|
||||||
|
|||||||
40
README_CN.md
40
README_CN.md
@@ -1,36 +1,28 @@
|
|||||||
# SuperSonic
|
# SuperSonic (超音数)
|
||||||
|
|
||||||
**SuperSonic融合Chat BI(powered by LLM)和Headless BI(powered by 语义层)打造新一代的BI平台**。这种融合确保了Chat BI能够与传统BI一样访问统一化治理的语义数据模型。此外,两种BI新范式都从中获得收益:
|
**SuperSonic融合ChatBI和HeadlessBI打造新一代的数据分析平台**。通过SuperSonic的问答对话界面,用户能够使用自然语言查询数据,系统会选择合适的可视化图表呈现结果。SuperSonic不需要修改或复制数据,只需要在物理数据模型之上构建逻辑语义模型(指标/维度/实体的定义,以及它们的业务含义、相互间关系等),即可开启数据问答体验。与此同时,SuperSonic被设计为可插拔的框架,采用Java SPI机制来扩展定制功能。
|
||||||
|
|
||||||
- Chat BI的Text2SQL生成通过检索语义数据模型得到增强。
|
|
||||||
- Headless BI的查询接口通过支持自然语言API得到拓展。
|
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_ideas.png" height="75%" width="75%" align="center"/>
|
|
||||||
|
|
||||||
通过SuperSonic的问答对话界面,用户能够使用自然语言查询数据,系统会选择合适的可视化图表呈现结果。SuperSonic不需要修改或复制数据,只需要在物理数据模型之上构建逻辑语义模型(定义指标/维度/实体/标签,以及它们的业务含义、相互关系等),即可开启数据问答体验。与此同时,SuperSonic被设计为可插拔的框架,采用Java SPI机制来扩展定制功能。
|
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||||
|
|
||||||
## 项目动机
|
## 项目动机
|
||||||
|
|
||||||
大型语言模型(LLM)如ChatGPT的出现正在重塑信息检索的方式,引领数据分析领域的一种新范式,被称为Chat BI。为了实现Chat BI,学术界和工业界主要关注利用LLM的能力将自然语言转换为SQL,通常称为Text2SQL或NL2SQL。尽管一些方法显示出有希望的结果,但它们在大规模实际应用中的可靠性还不足。
|
大型语言模型(LLMs)如ChatGPT的出现正在重塑信息检索的方式。在数据分析领域,学术界和工业界主要关注利用深度学习模型将自然语言查询转换为SQL查询。虽然一些工作显示出有前景的结果,但它们的可靠性还达不到生产可用的要求。
|
||||||
|
|
||||||
与此同时,另一种新兴范式被称为Headless BI,它专注于构建统一的语义数据模型,并引起了广泛的关注。Headless BI通过一个通用的语义层来实现,通过开放的API公开一致的数据语义。
|
在我们看来,为了在实际场景发挥价值,有三个关键点:
|
||||||
|
1. 融合HeadlessBI,通过统一语义层封装底层数据细节(关联、键值、公式等),降低SQL生成的**复杂度**。
|
||||||
从我们的角度来看,Chat BI和Headless BI的融合有潜力在两个方面增强Text2SQL的能力:
|
2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。
|
||||||
|
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。
|
||||||
1. 将数据语义(如业务术语、列值等)纳入提示词中,使LLM能够更好地理解语义,以**减少幻觉**。
|
|
||||||
2. 将高级SQL语法(如连接、公式等)的生成从LLM卸载到语义层,以**减少复杂度**。
|
|
||||||
|
|
||||||
为了验证上述想法,我们开发了SuperSonic项目,并将其应用在实际的内部产品中。与此同时,我们将SuperSonic作为一个可扩展的框架开源,希望能够促进数据问答对话领域的进一步发展。
|
为了验证上述想法,我们开发了SuperSonic项目,并将其应用在实际的内部产品中。与此同时,我们将SuperSonic作为一个可扩展的框架开源,希望能够促进数据问答对话领域的进一步发展。
|
||||||
|
|
||||||
## 开箱即用的特性
|
## 开箱即用的特性
|
||||||
|
|
||||||
- 内置Chat BI界面以便*业务用户*输入数据查询。
|
- 内置ChatBI界面以便*业务用户*输入数据查询。
|
||||||
- 内置Headless BI界面以便*分析工程师*构建语义模型。
|
- 内置HeadlessBI界面以便*分析工程师*构建语义模型。
|
||||||
- 内置基于规则的语义解析器,在特定场景(比如DEMO演示、集成测试)可以提升推理效率。
|
- 内置图形用户界面以便*系统管理员*管理第三方插件和对话助理。
|
||||||
- 支持文本输入联想、多轮对话、查询后问题推荐等高级特征。
|
- 支持文本输入的联想和查询问题的推荐。
|
||||||
- 支持三级权限控制:数据集级、列级、行级。
|
- 支持多轮对话,根据语境自动切换上下文。
|
||||||
|
- 支持四级权限控制:主题域级、模型级、列级、行级。
|
||||||
|
|
||||||
## 易于扩展的组件
|
## 易于扩展的组件
|
||||||
|
|
||||||
@@ -42,11 +34,11 @@ SuperSonic的整体架构和主流程如下图所示:
|
|||||||
|
|
||||||
- **模式映射器(Schema Mapper):** 将自然语言文本在知识库中进行匹配,为后续的语义解析提供相关信息。
|
- **模式映射器(Schema Mapper):** 将自然语言文本在知识库中进行匹配,为后续的语义解析提供相关信息。
|
||||||
|
|
||||||
- **语义解析器(Semantic Parser):** 理解用户查询并抽取语义信息,生成语义查询语句S2SQL。
|
- **语义解析器(Semantic Parser):** 理解用户查询并抽取语义信息,其由一组基于规则和基于模型的解析器组成,每个解析器可应对不同的特定场景。
|
||||||
|
|
||||||
- **语义修正器(Semantic Corrector):** 检查语义查询语句的合法性,对不合法的信息做修正和优化处理。
|
- **语义修正器(Semantic Corrector):** 检查语义信息的合法性,对不合法的信息做修正和优化处理。
|
||||||
|
|
||||||
- **语义翻译器(Semantic Translator):** 将语义查询语句翻译成可在物理数据模型上执行的SQL语句。
|
- **语义解释器(Semantic Interpreter):** 根据语义信息生成物理SQL执行查询。
|
||||||
|
|
||||||
- **问答插件(Chat Plugin):** 通过第三方工具扩展功能。给定所有配置的插件及其功能描述和示例问题,大语言模型将选择最合适的插件。
|
- **问答插件(Chat Plugin):** 通过第三方工具扩展功能。给定所有配置的插件及其功能描述和示例问题,大语言模型将选择最合适的插件。
|
||||||
|
|
||||||
|
|||||||
@@ -1,98 +1,72 @@
|
|||||||
@echo off
|
@echo off
|
||||||
setlocal enabledelayedexpansion
|
setlocal
|
||||||
chcp 65001
|
chcp 65001
|
||||||
|
set "sbinDir=%~dp0"
|
||||||
call supersonic-common.bat %*
|
set "baseDir=%~dp0.."
|
||||||
|
set "buildDir=%baseDir%\build"
|
||||||
|
set "runtimeDir=%baseDir%\..\runtime"
|
||||||
|
set "pip_path=pip3"
|
||||||
set "service=%~1"
|
set "service=%~1"
|
||||||
|
|
||||||
cd %projectDir%
|
|
||||||
if "%service%"=="" (
|
rem 1. build backend java modules
|
||||||
set service=%standalone_service%
|
del /q "%buildDir%\*.tar.gz" 2>NUL
|
||||||
|
call mvn -f "%baseDir%\..\pom.xml" clean package -DskipTests
|
||||||
|
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
ECHO Failed to build backend Java modules.
|
||||||
|
EXIT /B 1
|
||||||
)
|
)
|
||||||
|
|
||||||
call mvn help:evaluate -Dexpression=project.version > temp.txt
|
rem 2. move package to build
|
||||||
for /f "delims=" %%i in (temp.txt) do (
|
echo f|xcopy "%baseDir%\..\launchers\standalone\target\*.tar.gz" "%buildDir%\supersonic-standalone.tar.gz"
|
||||||
set line=%%i
|
|
||||||
if not "!line:~0,1!"=="[" (
|
|
||||||
set MVN_VERSION=!line!
|
|
||||||
)
|
|
||||||
)
|
|
||||||
del temp.txt
|
|
||||||
cd %baseDir%
|
|
||||||
|
|
||||||
|
rem 3. build frontend webapp
|
||||||
|
cd "%baseDir%\..\webapp"
|
||||||
|
call start-fe-prod.bat
|
||||||
|
copy /y "%baseDir%\..\webapp\supersonic-webapp.tar.gz" "%buildDir%\"
|
||||||
|
|
||||||
if "%service%"=="%pyllm_service%" (
|
IF ERRORLEVEL 1 (
|
||||||
echo start installing python modules required by supersonic-pyllm: %pip_path%
|
ECHO Failed to build frontend webapp.
|
||||||
%pip_path% install -r %projectDir%\headless\python\requirements.txt"
|
EXIT /B 1
|
||||||
echo install python modules success
|
|
||||||
goto :EOF
|
|
||||||
) else if "%service%"=="webapp" (
|
|
||||||
call :buildWebapp
|
|
||||||
tar xvf supersonic-webapp.tar.gz
|
|
||||||
move /y supersonic-webapp webapp
|
|
||||||
move /y webapp %projectDir%\launchers\%STANDALONE_SERVICE%\target\classes
|
|
||||||
goto :EOF
|
|
||||||
) else (
|
|
||||||
call :buildJavaService
|
|
||||||
call :buildWebapp
|
|
||||||
call :packageRelease
|
|
||||||
goto :EOF
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rem 4. copy webapp to java classpath
|
||||||
|
cd "%buildDir%"
|
||||||
|
tar -zxvf supersonic-webapp.tar.gz
|
||||||
|
move supersonic-webapp webapp
|
||||||
|
move webapp ..\..\launchers\standalone\target\classes
|
||||||
|
|
||||||
:buildJavaService
|
rem 5. build backend python modules
|
||||||
set "model_name=%service%"
|
if "%service%"=="pyllm" (
|
||||||
echo "starting building supersonic-%model_name% service"
|
echo "start installing python modules with pip: ${pip_path}"
|
||||||
call mvn -f %projectDir%\launchers\%model_name% clean package -DskipTests
|
set requirementPath="%baseDir%/../chat/python/requirements.txt"
|
||||||
IF ERRORLEVEL 1 (
|
%pip_path% install -r %requirementPath%
|
||||||
ECHO Failed to build backend Java modules.
|
echo "install python modules success"
|
||||||
EXIT /B 1
|
)
|
||||||
)
|
|
||||||
copy /y %projectDir%\launchers\%model_name%\target\*.tar.gz %buildDir%\
|
|
||||||
echo "finished building supersonic-%model_name% service"
|
|
||||||
goto :EOF
|
|
||||||
|
|
||||||
|
call :BUILD_RUNTIME
|
||||||
|
|
||||||
:buildWebapp
|
:BUILD_RUNTIME
|
||||||
echo "starting building supersonic webapp"
|
rem 6. reset runtime
|
||||||
cd %projectDir%\webapp
|
IF EXIST "%runtimeDir%" (
|
||||||
call start-fe-prod.bat
|
echo begin to delete dir : %runtimeDir%
|
||||||
copy /y supersonic-webapp.tar.gz %buildDir%\
|
rd /s /q "%runtimeDir%"
|
||||||
rem check build result
|
) ELSE (
|
||||||
IF ERRORLEVEL 1 (
|
echo %runtimeDir% does not exist, create directly
|
||||||
ECHO Failed to build frontend webapp.
|
)
|
||||||
EXIT /B 1
|
mkdir "%runtimeDir%"
|
||||||
)
|
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
|
||||||
echo "finished building supersonic webapp"
|
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
|
||||||
goto :EOF
|
move "%%f" "%runtimeDir%\supersonic-standalone"
|
||||||
|
)
|
||||||
|
|
||||||
|
rem 7. copy webapp to runtime
|
||||||
:packageRelease
|
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
|
||||||
set "model_name=%service%"
|
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
set "release_dir=supersonic-%model_name%-%MVN_VERSION%"
|
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
set "service_name=launchers-%model_name%-%MVN_VERSION%"
|
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
echo "starting packaging supersonic release"
|
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
cd %buildDir%
|
rd /s /q "%buildDir%\supersonic-webapp"
|
||||||
if exist %release_dir% rmdir /s /q %release_dir%
|
|
||||||
if exist %release_dir%.zip del %release_dir%.zip
|
|
||||||
mkdir %release_dir%
|
|
||||||
rem package webapp
|
|
||||||
tar xvf supersonic-webapp.tar.gz
|
|
||||||
move /y supersonic-webapp webapp
|
|
||||||
echo {"env": ""} > webapp\supersonic.config.json
|
|
||||||
move /y webapp %release_dir%
|
|
||||||
rem package java service
|
|
||||||
tar xvf %service_name%-bin.tar.gz
|
|
||||||
for /d %%D in ("%service_name%\*") do (
|
|
||||||
move "%%D" "%release_dir%"
|
|
||||||
)
|
|
||||||
rem generate zip file
|
|
||||||
powershell Compress-Archive -Path %release_dir% -DestinationPath %release_dir%.zip
|
|
||||||
del %service_name%-bin.tar.gz
|
|
||||||
del supersonic-webapp.tar.gz
|
|
||||||
rmdir /s /q %service_name%
|
|
||||||
echo "finished packaging supersonic release"
|
|
||||||
goto :EOF
|
|
||||||
|
|
||||||
endlocal
|
endlocal
|
||||||
@@ -1,80 +1,58 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -x
|
||||||
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||||
chmod +x $sbinDir/supersonic-common.sh
|
chmod +x $sbinDir/supersonic-common.sh
|
||||||
source $sbinDir/supersonic-common.sh
|
source $sbinDir/supersonic-common.sh
|
||||||
cd $projectDir
|
|
||||||
MVN_VERSION=$(mvn help:evaluate -Dexpression=project.version | grep -e '^[^\[]')
|
|
||||||
|
|
||||||
cd $baseDir
|
cd $baseDir
|
||||||
|
|
||||||
service=$1
|
service=$1
|
||||||
if [ -z "$service" ]; then
|
#1. build backend java modules
|
||||||
service=${STANDALONE_SERVICE}
|
rm -fr ${buildDir}/*.tar.gz
|
||||||
|
rm -fr dist
|
||||||
|
set +x
|
||||||
|
mvn -f $baseDir/../ clean package -DskipTests
|
||||||
|
# check build result
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed to build backend Java modules."
|
||||||
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
function buildJavaService {
|
#2. move package to build
|
||||||
model_name=$1
|
cp $baseDir/../launchers/headless/target/*.tar.gz ${buildDir}/supersonic-headless.tar.gz
|
||||||
echo "starting building supersonic-${model_name} service"
|
cp $baseDir/../launchers/chat/target/*.tar.gz ${buildDir}/supersonic-chat.tar.gz
|
||||||
mvn -f $projectDir clean package -DskipTests
|
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic-standalone.tar.gz
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
echo "Failed to build backend Java modules."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
cp $projectDir/launchers/${model_name}/target/*.tar.gz ${buildDir}/
|
|
||||||
echo "finished building supersonic-${model_name} service"
|
|
||||||
}
|
|
||||||
|
|
||||||
function buildWebapp {
|
#3. build frontend webapp
|
||||||
echo "starting building supersonic webapp"
|
chmod +x $baseDir/../webapp/start-fe-prod.sh
|
||||||
chmod +x $projectDir/webapp/start-fe-prod.sh
|
cd ../webapp
|
||||||
cd $projectDir/webapp
|
sh ./start-fe-prod.sh
|
||||||
sh ./start-fe-prod.sh
|
cp -fr ./supersonic-webapp.tar.gz ${buildDir}/
|
||||||
cp -fr ./supersonic-webapp.tar.gz ${buildDir}/
|
|
||||||
# check build result
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
echo "Failed to build frontend webapp."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
echo "finished building supersonic webapp"
|
|
||||||
}
|
|
||||||
|
|
||||||
function packageRelease {
|
# check build result
|
||||||
model_name=$1
|
if [ $? -ne 0 ]; then
|
||||||
release_dir=supersonic-${model_name}-${MVN_VERSION}
|
echo "Failed to build frontend webapp."
|
||||||
service_name=launchers-${model_name}-${MVN_VERSION}
|
exit 1
|
||||||
echo "starting packaging supersonic release"
|
fi
|
||||||
cd $buildDir
|
#4. copy webapp to java classpath
|
||||||
mkdir $release_dir
|
cd $buildDir
|
||||||
# package webapp
|
tar xvf supersonic-webapp.tar.gz
|
||||||
tar xvf supersonic-webapp.tar.gz
|
mv supersonic-webapp webapp
|
||||||
mv supersonic-webapp webapp
|
cp -fr webapp ../../launchers/headless/target/classes
|
||||||
json='{"env": "''"}'
|
cp -fr webapp ../../launchers/chat/target/classes
|
||||||
echo $json > webapp/supersonic.config.json
|
cp -fr webapp ../../launchers/standalone/target/classes
|
||||||
mv webapp $release_dir/
|
rm -fr ${buildDir}/webapp
|
||||||
# package java service
|
|
||||||
tar xvf $service_name-bin.tar.gz
|
|
||||||
mv $service_name/* $release_dir/
|
|
||||||
# generate zip file
|
|
||||||
zip -r $release_dir.zip $release_dir
|
|
||||||
# delete intermediate files
|
|
||||||
rm supersonic-webapp.tar.gz $service_name-bin.tar.gz
|
|
||||||
rm -rf webapp $service_name $release_dir
|
|
||||||
echo "finished packaging supersonic release"
|
|
||||||
}
|
|
||||||
|
|
||||||
#1. build backend services
|
#5. build backend python modules
|
||||||
if [ "$service" == $PYLLM_SERVICE ]; then
|
if [ "$service" == "pyllm" ]; then
|
||||||
echo "start installing python modules required by supersonic-pyllm: ${pip_path}"
|
echo "start installing python modules with pip: ${pip_path}"
|
||||||
requirementPath=$projectDir/headless/python/requirements.txt
|
requirementPath=$baseDir/../chat/python/requirements.txt
|
||||||
${pip_path} install -r ${requirementPath}
|
${pip_path} install -r ${requirementPath}
|
||||||
echo "install python modules success"
|
echo "install python modules success"
|
||||||
elif [ "$service" == "webapp" ]; then
|
fi
|
||||||
buildWebapp
|
|
||||||
target_path=$projectDir/launchers/$STANDALONE_SERVICE/target/classes
|
#6. reset runtime
|
||||||
tar xvf $projectDir/webapp/supersonic-webapp.tar.gz -C $target_path
|
rm -fr $runtimeDir/supersonic*
|
||||||
mv $target_path/supersonic-webapp $target_path/webapp
|
moveAllToRuntime
|
||||||
else
|
setEnvToWeb chat
|
||||||
buildJavaService $service
|
setEnvToWeb headless
|
||||||
buildWebapp
|
|
||||||
packageRelease $service
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
set "sbinDir=%~dp0"
|
|
||||||
set "baseDir=%~dp0.."
|
|
||||||
set "buildDir=%baseDir%\build"
|
|
||||||
set "main_class=com.tencent.supersonic.StandaloneLauncher"
|
|
||||||
set "python_path=python"
|
|
||||||
set "pip_path=pip3"
|
|
||||||
set "standalone_service=standalone"
|
|
||||||
set "pyllm_service=pyllm"
|
|
||||||
set "projectDir=%baseDir%\.."
|
|
||||||
@@ -6,19 +6,105 @@ pip_path=${PIP_PATH:-"pip3"}
|
|||||||
|
|
||||||
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||||
baseDir=$(cd "$sbinDir/.." && pwd -P)
|
baseDir=$(cd "$sbinDir/.." && pwd -P)
|
||||||
runtimeDir=$baseDir/runtime
|
runtimeDir=$baseDir/../runtime
|
||||||
buildDir=$baseDir/build
|
buildDir=$baseDir/build
|
||||||
projectDir=$baseDir/..
|
|
||||||
|
|
||||||
readonly CHAT_APP_NAME="supersonic_chat"
|
readonly CHAT_APP_NAME="supersonic_chat"
|
||||||
readonly HEADLESS_APP_NAME="supersonic_headless"
|
readonly HEADLESS_APP_NAME="supersonic_headless"
|
||||||
readonly PYLLM_APP_NAME="supersonic_pyllm"
|
readonly PYLLM_APP_NAME="supersonic_pyllm"
|
||||||
readonly STANDALONE_APP_NAME="supersonic_standalone"
|
readonly STANDALONE_APP_NAME="supersonic_standalone"
|
||||||
|
|
||||||
readonly CHAT_SERVICE="chat"
|
readonly CHAT_SERVICE="chat"
|
||||||
readonly HEADLESS_SERVICE="headless"
|
readonly HEADLESS_SERVICE="headless"
|
||||||
readonly PYLLM_SERVICE="pyllm"
|
readonly PYLLM_SERVICE="pyllm"
|
||||||
readonly STANDALONE_SERVICE="standalone"
|
readonly STANDALONE_SERVICE="standalone"
|
||||||
|
|
||||||
readonly PYLLM_HOST="127.0.0.1"
|
readonly PYLLM_HOST="127.0.0.1"
|
||||||
readonly PYLLM_PORT="9092"
|
readonly PYLLM_PORT="9092"
|
||||||
|
|
||||||
|
function setEnvToWeb {
|
||||||
|
model_name=$1
|
||||||
|
json='{"env": "'$model_name'"}'
|
||||||
|
echo $json > ${runtimeDir}/supersonic-${model_name}/webapp/supersonic.config.json
|
||||||
|
echo $json > $baseDir/../launchers/${model_name}/target/classes/webapp/supersonic.config.json
|
||||||
|
}
|
||||||
|
|
||||||
|
function moveToRuntime {
|
||||||
|
model_name=$1
|
||||||
|
file="${buildDir}/supersonic-${model_name}.tar.gz"
|
||||||
|
if [ -f "$file" ]; then
|
||||||
|
tar -zxvf "$file" -C ${runtimeDir}
|
||||||
|
mv ${runtimeDir}/launchers-${model_name}-* ${runtimeDir}/supersonic-${model_name}
|
||||||
|
mkdir -p ${runtimeDir}/supersonic-${model_name}/webapp
|
||||||
|
cp -fr ${buildDir}/webapp/* ${runtimeDir}/supersonic-${model_name}/webapp
|
||||||
|
else
|
||||||
|
echo "File $file does not exist. Skipping the move to runtime."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
function moveAllToRuntime {
|
||||||
|
mkdir -p ${runtimeDir}
|
||||||
|
tar xvf ${buildDir}/supersonic-webapp.tar.gz -C ${buildDir}
|
||||||
|
mv ${buildDir}/supersonic-webapp ${buildDir}/webapp
|
||||||
|
|
||||||
|
moveToRuntime chat
|
||||||
|
moveToRuntime headless
|
||||||
|
moveToRuntime standalone
|
||||||
|
rm -fr ${buildDir}/webapp
|
||||||
|
}
|
||||||
|
|
||||||
|
# run java service
|
||||||
|
function runJavaService {
|
||||||
|
javaRunDir=${runtimeDir}/supersonic-${model_name}
|
||||||
|
local_app_name=$1
|
||||||
|
libDir=$javaRunDir/lib
|
||||||
|
confDir=$javaRunDir/conf
|
||||||
|
|
||||||
|
CLASSPATH=""
|
||||||
|
CLASSPATH=$CLASSPATH:$confDir
|
||||||
|
|
||||||
|
for jarPath in $libDir/*.jar; do
|
||||||
|
CLASSPATH=$CLASSPATH:$jarPath
|
||||||
|
done
|
||||||
|
|
||||||
|
export CLASSPATH
|
||||||
|
export LANG="zh_CN.UTF-8"
|
||||||
|
|
||||||
|
cd $javaRunDir
|
||||||
|
if [[ "$JAVA_HOME" == "" ]]; then
|
||||||
|
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
||||||
|
fi
|
||||||
|
export PATH=$JAVA_HOME/bin:$PATH
|
||||||
|
command="-Dfile.encoding="UTF-8" -Duser.language="Zh" -Duser.region="CN" -Duser.timezone="GMT+08" -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m "$main_class
|
||||||
|
|
||||||
|
mkdir -p $javaRunDir/logs
|
||||||
|
if [[ "$is_test" == "true" ]]; then
|
||||||
|
java -Dspring.profiles.active="dev" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||||
|
else
|
||||||
|
java $command $javaRunDir >/dev/null 2>$javaRunDir/logs/error.log &
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# run python service
|
||||||
|
function runPythonService {
|
||||||
|
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
|
||||||
|
cd $pythonRunDir
|
||||||
|
nohup ${python_path} supersonic_pyllm.py > $pythonRunDir/pyllm.log 2>&1 &
|
||||||
|
# add health check
|
||||||
|
for i in {1..10}
|
||||||
|
do
|
||||||
|
echo "pyllm health check attempt $i..."
|
||||||
|
response=$(curl -s http://${PYLLM_HOST}:${PYLLM_PORT}/health)
|
||||||
|
echo "pyllm health check response: $response"
|
||||||
|
status_ok="Healthy"
|
||||||
|
if [[ $response == *$status_ok* ]] ; then
|
||||||
|
echo "pyllm Health check passed."
|
||||||
|
break
|
||||||
|
else
|
||||||
|
if [ "$i" -eq 10 ]; then
|
||||||
|
echo "pyllm Health check failed after 10 attempts."
|
||||||
|
echo "May still downloading model files. Please check pyllm.log in runtime directory."
|
||||||
|
fi
|
||||||
|
echo "Retrying after 5 seconds..."
|
||||||
|
sleep 5
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,102 +1,118 @@
|
|||||||
@echo off
|
@echo off
|
||||||
setlocal
|
setlocal
|
||||||
chcp 65001
|
chcp 65001
|
||||||
|
set "sbinDir=%~dp0"
|
||||||
|
set "baseDir=%~dp0.."
|
||||||
|
set "runtimeDir=%baseDir%\..\runtime"
|
||||||
|
set "buildDir=%baseDir%\build"
|
||||||
|
set "main_class=com.tencent.supersonic.StandaloneLauncher"
|
||||||
|
set "python_path=python"
|
||||||
|
set "pip_path=pip3"
|
||||||
|
set "standalone_service=standalone"
|
||||||
|
set "pyllm_service=pyllm"
|
||||||
|
|
||||||
call supersonic-common.bat %*
|
set "javaRunDir=%runtimeDir%\supersonic-standalone"
|
||||||
call %sbinDir%/../conf/supersonic-env.bat %*
|
set "pythonRunDir=%runtimeDir%\supersonic-standalone\pyllm"
|
||||||
|
|
||||||
set "command=%~1"
|
set "command=%~1"
|
||||||
set "service=%~2"
|
set "service=%~2"
|
||||||
|
|
||||||
if "%service%"=="" (
|
if "%service%"=="" (
|
||||||
set "service=%standalone_service%"
|
set "service=%standalone_service%"
|
||||||
)
|
)
|
||||||
set "model_name=%service%"
|
|
||||||
IF "%service%"=="pyllm" (
|
IF "%service%"=="pyllm" (
|
||||||
set "llmProxy=PythonLLMProxy"
|
SET "llmProxy=PythonLLMProxy"
|
||||||
set "model_name=%standalone_service%"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cd %baseDir%
|
call :BUILD_RUNTIME
|
||||||
|
|
||||||
if "%command%"=="restart" (
|
if "%command%"=="restart" (
|
||||||
call :stop
|
call :STOP
|
||||||
call :start
|
call :START
|
||||||
goto :EOF
|
goto :EOF
|
||||||
) else if "%command%"=="start" (
|
) else if "%command%"=="start" (
|
||||||
call :start
|
call :START
|
||||||
goto :EOF
|
goto :EOF
|
||||||
) else if "%command%"=="stop" (
|
) else if "%command%"=="stop" (
|
||||||
call :stop
|
call :STOP
|
||||||
goto :EOF
|
goto :EOF
|
||||||
) else if "%command%"=="reload" (
|
) else if "%command%"=="reload" (
|
||||||
call :reloadExamples
|
call :RELOAD_EXAMPLE
|
||||||
goto :EOF
|
goto :EOF
|
||||||
) else (
|
) else (
|
||||||
echo "Use command {start|stop|restart} to run."
|
echo "Use command {start|stop|restart} to run."
|
||||||
goto :EOF
|
goto :EOF
|
||||||
)
|
)
|
||||||
|
|
||||||
|
:START
|
||||||
: start
|
if "%service%"=="%pyllm_service%" (
|
||||||
if "%service%"=="%pyllm_service%" (
|
call :START_PYTHON
|
||||||
call :runPythonService
|
call :START_JAVA
|
||||||
call :runJavaService
|
|
||||||
goto :EOF
|
goto :EOF
|
||||||
)
|
)
|
||||||
call :runJavaService
|
call :START_JAVA
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:STOP
|
||||||
|
call :STOP_PYTHON
|
||||||
|
call :STOP_JAVA
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:START_PYTHON
|
||||||
|
echo 'python service starting, see logs in pyllm/pyllm.log'
|
||||||
|
cd "%pythonRunDir%"
|
||||||
|
start /B %python_path% supersonic_pyllm.py > %pythonRunDir%\pyllm.log 2>&1
|
||||||
|
timeout /t 10 >nul
|
||||||
|
echo 'python service started'
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
|
:START_JAVA
|
||||||
: stop
|
echo 'java service starting, see logs in logs/'
|
||||||
call :stopPythonService
|
cd "%javaRunDir%"
|
||||||
call :stopJavaService
|
if not exist "%runtimeDir%\supersonic-standalone\logs" mkdir "%runtimeDir%\supersonic-standalone\logs"
|
||||||
goto :EOF
|
set "libDir=%runtimeDir%\supersonic-standalone\lib"
|
||||||
|
set "confDir=%runtimeDir%\supersonic-standalone\conf"
|
||||||
|
set "webDir=%runtimeDir%\supersonic-standalone\webapp"
|
||||||
: reloadExamples
|
set "classpath=%confDir%;%webDir%;%libDir%\*"
|
||||||
set "pythonRunDir=%baseDir%\pyllm"
|
|
||||||
cd "%pythonRunDir%\sql"
|
|
||||||
start %python_path% examples_reload_run.py
|
|
||||||
goto :EOF
|
|
||||||
|
|
||||||
|
|
||||||
: runJavaService
|
|
||||||
echo 'java service starting, see logs in logs/'
|
|
||||||
set "libDir=%baseDir%\lib"
|
|
||||||
set "confDir=%baseDir%\conf"
|
|
||||||
set "webDir=%baseDir%\webapp"
|
|
||||||
set "logDir=%baseDir%\logs"
|
|
||||||
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
|
|
||||||
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||||
if not exist %logDir% mkdir %logDir%
|
|
||||||
start /B java %java-command% >nul 2>&1
|
start /B java %java-command% >nul 2>&1
|
||||||
timeout /t 10 >nul
|
timeout /t 10 >nul
|
||||||
echo 'java service started'
|
echo 'java service started'
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
|
:STOP_PYTHON
|
||||||
: runPythonService
|
|
||||||
echo 'python service starting, see logs in pyllm\pyllm.log'
|
|
||||||
set "pythonRunDir=%baseDir%\pyllm"
|
|
||||||
start /B %python_path% %pythonRunDir%\supersonic_pyllm.py > %pythonRunDir%\pyllm.log 2>&1
|
|
||||||
timeout /t 10 >nul
|
|
||||||
echo 'python service started'
|
|
||||||
goto :EOF
|
|
||||||
|
|
||||||
|
|
||||||
: stopPythonService
|
|
||||||
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "python"') do (
|
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "python"') do (
|
||||||
taskkill /PID %%i /F
|
taskkill /PID %%i /F
|
||||||
echo "python service (PID = %%i) is killed."
|
echo "python service (PID = %%i) is killed."
|
||||||
)
|
)
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
|
:STOP_JAVA
|
||||||
: stopJavaService
|
|
||||||
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
||||||
taskkill /PID %%i /F
|
taskkill /PID %%i /F
|
||||||
echo "java service (PID = %%i) is killed."
|
echo "java service (PID = %%i) is killed."
|
||||||
)
|
)
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
endlocal
|
:RELOAD_EXAMPLE
|
||||||
|
cd "%runtimeDir%\supersonic-standalone\pyllm\sql"
|
||||||
|
start %python_path% examples_reload_run.py
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:BUILD_RUNTIME
|
||||||
|
rem 6. reset runtime
|
||||||
|
if exist "%runtimeDir%" goto :EOF
|
||||||
|
mkdir "%runtimeDir%"
|
||||||
|
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
|
||||||
|
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
|
||||||
|
move "%%f" "%runtimeDir%\supersonic-standalone"
|
||||||
|
)
|
||||||
|
|
||||||
|
rem 7. copy webapp to runtime
|
||||||
|
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
|
||||||
|
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
|
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
|
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
|
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
|
rd /s /q "%buildDir%\supersonic-webapp"
|
||||||
@@ -1,11 +1,16 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -x
|
||||||
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||||
|
chmod +x $sbinDir/supersonic-common.sh
|
||||||
source $sbinDir/supersonic-common.sh
|
source $sbinDir/supersonic-common.sh
|
||||||
|
|
||||||
set -a
|
# 1.init environment parameters
|
||||||
source $sbinDir/../conf/supersonic-env.sh
|
if [ ! -d "$runtimeDir" ]; then
|
||||||
set +a
|
echo "the runtime dir does not exist move all to runtime"
|
||||||
|
moveAllToRuntime
|
||||||
|
fi
|
||||||
|
set +x
|
||||||
|
|
||||||
command=$1
|
command=$1
|
||||||
service=$2
|
service=$2
|
||||||
@@ -13,93 +18,44 @@ if [ -z "$service" ]; then
|
|||||||
service=${STANDALONE_SERVICE}
|
service=${STANDALONE_SERVICE}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
app_name=$STANDALONE_APP_NAME
|
||||||
|
main_class="com.tencent.supersonic.StandaloneLauncher"
|
||||||
model_name=$service
|
model_name=$service
|
||||||
|
|
||||||
if [ "$service" == "pyllm" ]; then
|
if [ "$service" == "pyllm" ]; then
|
||||||
model_name=${STANDALONE_SERVICE}
|
model_name=${STANDALONE_SERVICE}
|
||||||
export llmProxy=PythonLLMProxy
|
export llmProxy=PythonLLMProxy
|
||||||
fi
|
fi
|
||||||
|
|
||||||
cd $baseDir
|
cd $baseDir
|
||||||
|
|
||||||
|
# 2.set main class
|
||||||
function setMainClass {
|
function setMainClass {
|
||||||
if [ "$service" == $CHAT_SERVICE ]; then
|
if [ "$service" == $CHAT_SERVICE ]; then
|
||||||
main_class="com.tencent.supersonic.ChatLauncher"
|
main_class="com.tencent.supersonic.ChatLauncher"
|
||||||
elif [ "$service" == $HEADLESS_SERVICE ]; then
|
elif [ "$service" == $HEADLESS_SERVICE ]; then
|
||||||
main_class="com.tencent.supersonic.HeadlessLauncher"
|
main_class="com.tencent.supersonic.HeadlessLauncher"
|
||||||
else
|
|
||||||
main_class="com.tencent.supersonic.StandaloneLauncher"
|
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
setMainClass
|
||||||
|
# 3.set app name
|
||||||
function setAppName {
|
function setAppName {
|
||||||
if [ "$service" == $CHAT_SERVICE ]; then
|
if [ "$service" == $CHAT_SERVICE ]; then
|
||||||
app_name=$CHAT_APP_NAME
|
app_name=$CHAT_APP_NAME
|
||||||
elif [ "$service" == $HEADLESS_SERVICE ]; then
|
elif [ "$service" == $HEADLESS_SERVICE ]; then
|
||||||
app_name=$HEADLESS_APP_NAME
|
app_name=$HEADLESS_APP_NAME
|
||||||
else
|
elif [ "$service" == $PYLLM_SERVICE ]; then
|
||||||
app_name=$STANDALONE_APP_NAME
|
app_name=$PYLLM_APP_NAME
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
setAppName
|
||||||
|
|
||||||
function reloadExamples {
|
function reloadExamples {
|
||||||
cd $baseDir/pyllm/sql
|
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
|
||||||
|
cd $pythonRunDir/sql
|
||||||
${python_path} examples_reload_run.py
|
${python_path} examples_reload_run.py
|
||||||
}
|
}
|
||||||
|
|
||||||
function runJavaService {
|
|
||||||
javaRunDir=$baseDir
|
|
||||||
local_app_name=$1
|
|
||||||
libDir=$baseDir/lib
|
|
||||||
confDir=$baseDir/conf
|
|
||||||
|
|
||||||
CLASSPATH=""
|
|
||||||
CLASSPATH=$CLASSPATH:$confDir
|
|
||||||
|
|
||||||
for jarPath in $libDir/*.jar; do
|
|
||||||
CLASSPATH=$CLASSPATH:$jarPath
|
|
||||||
done
|
|
||||||
|
|
||||||
export CLASSPATH
|
|
||||||
export LANG="zh_CN.UTF-8"
|
|
||||||
|
|
||||||
cd $javaRunDir
|
|
||||||
if [[ "$JAVA_HOME" == "" ]]; then
|
|
||||||
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
|
||||||
fi
|
|
||||||
export PATH=$JAVA_HOME/bin:$PATH
|
|
||||||
command="-Dfile.encoding="UTF-8" -Duser.language="Zh" -Duser.region="CN" -Duser.timezone="GMT+08" -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m "$main_class
|
|
||||||
|
|
||||||
mkdir -p $javaRunDir/logs
|
|
||||||
if [[ "$is_test" == "true" ]]; then
|
|
||||||
java -Dspring.profiles.active="dev" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
|
||||||
else
|
|
||||||
java $command $javaRunDir >/dev/null 2>$javaRunDir/logs/error.log &
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
function runPythonService {
|
|
||||||
pythonRunDir=$baseDir/pyllm
|
|
||||||
cd $pythonRunDir
|
|
||||||
nohup ${python_path} supersonic_pyllm.py > $pythonRunDir/pyllm.log 2>&1 &
|
|
||||||
# add health check
|
|
||||||
for i in {1..10}
|
|
||||||
do
|
|
||||||
echo "pyllm health check attempt $i..."
|
|
||||||
response=$(curl -s http://${PYLLM_HOST}:${PYLLM_PORT}/health)
|
|
||||||
echo "pyllm health check response: $response"
|
|
||||||
status_ok="Healthy"
|
|
||||||
if [[ $response == *$status_ok* ]] ; then
|
|
||||||
echo "pyllm Health check passed."
|
|
||||||
break
|
|
||||||
else
|
|
||||||
if [ "$i" -eq 10 ]; then
|
|
||||||
echo "pyllm Health check failed after 10 attempts."
|
|
||||||
echo "May still downloading model files. Please check pyllm.log in runtime directory."
|
|
||||||
fi
|
|
||||||
echo "Retrying after 5 seconds..."
|
|
||||||
sleep 5
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
}
|
|
||||||
|
|
||||||
function start()
|
function start()
|
||||||
{
|
{
|
||||||
@@ -137,16 +93,18 @@ function reload()
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
setMainClass
|
# 4. execute command operation
|
||||||
setAppName
|
|
||||||
case "$command" in
|
case "$command" in
|
||||||
start)
|
start)
|
||||||
if [ "$service" == $PYLLM_SERVICE ]; then
|
if [ "$service" == $PYLLM_SERVICE ]; then
|
||||||
echo "Starting $PYLLM_APP_NAME"
|
echo "Starting $app_name"
|
||||||
start $PYLLM_APP_NAME
|
start $app_name
|
||||||
|
echo "Starting $STANDALONE_APP_NAME"
|
||||||
|
start $STANDALONE_APP_NAME
|
||||||
|
else
|
||||||
|
echo "Starting $app_name"
|
||||||
|
start $app_name
|
||||||
fi
|
fi
|
||||||
echo "Starting ${app_name}"
|
|
||||||
start ${app_name}
|
|
||||||
echo "Start success"
|
echo "Start success"
|
||||||
;;
|
;;
|
||||||
stop)
|
stop)
|
||||||
@@ -163,15 +121,20 @@ case "$command" in
|
|||||||
;;
|
;;
|
||||||
restart)
|
restart)
|
||||||
if [ "$service" == $PYLLM_SERVICE ]; then
|
if [ "$service" == $PYLLM_SERVICE ]; then
|
||||||
echo "Stopping $PYLLM_APP_NAME"
|
echo "Stopping ${app_name}"
|
||||||
stop $PYLLM_APP_NAME
|
stop ${app_name}
|
||||||
echo "Starting $PYLLM_APP_NAME"
|
echo "Stopping ${STANDALONE_APP_NAME}"
|
||||||
start $PYLLM_APP_NAME
|
stop $STANDALONE_APP_NAME
|
||||||
|
echo "Starting ${app_name}"
|
||||||
|
start ${app_name}
|
||||||
|
echo "Starting ${STANDALONE_APP_NAME}"
|
||||||
|
start $STANDALONE_APP_NAME
|
||||||
|
else
|
||||||
|
echo "Stopping ${app_name}"
|
||||||
|
stop ${app_name}
|
||||||
|
echo "Starting ${app_name}"
|
||||||
|
start ${app_name}
|
||||||
fi
|
fi
|
||||||
echo "Stopping ${app_name}"
|
|
||||||
stop ${app_name}
|
|
||||||
echo "Starting ${app_name}"
|
|
||||||
start ${app_name}
|
|
||||||
echo "Restart success"
|
echo "Restart success"
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
|
|||||||
@@ -21,17 +21,11 @@
|
|||||||
</includes>
|
</includes>
|
||||||
</fileSet>
|
</fileSet>
|
||||||
<fileSet>
|
<fileSet>
|
||||||
<directory>${project.basedir}/../../headless/python</directory>
|
<directory>${project.basedir}/../../chat/python</directory>
|
||||||
<outputDirectory>pyllm</outputDirectory>
|
<outputDirectory>pyllm</outputDirectory>
|
||||||
<fileMode>0777</fileMode>
|
<fileMode>0777</fileMode>
|
||||||
<directoryMode>0755</directoryMode>
|
<directoryMode>0755</directoryMode>
|
||||||
</fileSet>
|
</fileSet>
|
||||||
<fileSet>
|
|
||||||
<directory>${project.basedir}/../../assembly/bin</directory>
|
|
||||||
<outputDirectory>bin</outputDirectory>
|
|
||||||
<fileMode>0777</fileMode>
|
|
||||||
<directoryMode>0755</directoryMode>
|
|
||||||
</fileSet>
|
|
||||||
</fileSets>
|
</fileSets>
|
||||||
|
|
||||||
<dependencySets>
|
<dependencySets>
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.auth.api.authentication.utils;
|
|||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
|
import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
|
||||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -20,10 +20,10 @@ public final class UserHolder {
|
|||||||
|
|
||||||
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
|
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
|
||||||
User user = REPO.findUser(request, response);
|
User user = REPO.findUser(request, response);
|
||||||
SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class);
|
SysParameterService sysParameterService = ContextUtils.getBean(SysParameterService.class);
|
||||||
SystemConfig systemConfig = sysParameterService.getSystemConfig();
|
SysParameter sysParameter = sysParameterService.getSysParameter();
|
||||||
if (!CollectionUtils.isEmpty(systemConfig.getAdmins())
|
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
|
||||||
&& systemConfig.getAdmins().contains(user.getName())) {
|
&& sysParameter.getAdmins().contains(user.getName())) {
|
||||||
user.setIsAdmin(1);
|
user.setIsAdmin(1);
|
||||||
}
|
}
|
||||||
return user;
|
return user;
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||||
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
|
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
@@ -18,9 +18,9 @@ import java.util.Set;
|
|||||||
@Service
|
@Service
|
||||||
public class UserServiceImpl implements UserService {
|
public class UserServiceImpl implements UserService {
|
||||||
|
|
||||||
private SystemConfigService sysParameterService;
|
private SysParameterService sysParameterService;
|
||||||
|
|
||||||
public UserServiceImpl(SystemConfigService sysParameterService) {
|
public UserServiceImpl(SysParameterService sysParameterService) {
|
||||||
this.sysParameterService = sysParameterService;
|
this.sysParameterService = sysParameterService;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,9 +28,9 @@ public class UserServiceImpl implements UserService {
|
|||||||
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||||
if (user != null) {
|
if (user != null) {
|
||||||
SystemConfig systemConfig = sysParameterService.getSystemConfig();
|
SysParameter sysParameter = sysParameterService.getSysParameter();
|
||||||
if (!CollectionUtils.isEmpty(systemConfig.getAdmins())
|
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
|
||||||
&& systemConfig.getAdmins().contains(user.getName())) {
|
&& sysParameter.getAdmins().contains(user.getName())) {
|
||||||
user.setIsAdmin(1);
|
user.setIsAdmin(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
public class SchemaMapInfo {
|
||||||
|
|
||||||
|
private Map<Long, List<SchemaElementMatch>> viewElementMatches = new HashMap<>();
|
||||||
|
|
||||||
|
public Set<Long> getMatchedViewInfos() {
|
||||||
|
return viewElementMatches.keySet();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElementMatch> getMatchedElements(Long view) {
|
||||||
|
return viewElementMatches.getOrDefault(view, Lists.newArrayList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Long, List<SchemaElementMatch>> getViewElementMatches() {
|
||||||
|
return viewElementMatches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setViewElementMatches(Map<Long, List<SchemaElementMatch>> viewElementMatches) {
|
||||||
|
this.viewElementMatches = viewElementMatches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMatchedElements(Long view, List<SchemaElementMatch> elementMatches) {
|
||||||
|
viewElementMatches.put(view, elementMatches);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,11 +1,15 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -22,12 +26,12 @@ public class SemanticParseInfo {
|
|||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode;
|
private String queryMode;
|
||||||
private SchemaElement dataSet;
|
private SchemaElement view;
|
||||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||||
private SchemaElement entity;
|
private SchemaElement entity;
|
||||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||||
private FilterType filterType = FilterType.AND;
|
private FilterType filterType = FilterType.UNION;
|
||||||
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
|
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
|
||||||
private Set<QueryFilter> metricFilters = new LinkedHashSet();
|
private Set<QueryFilter> metricFilters = new LinkedHashSet();
|
||||||
private Set<Order> orders = new LinkedHashSet();
|
private Set<Order> orders = new LinkedHashSet();
|
||||||
@@ -36,10 +40,10 @@ public class SemanticParseInfo {
|
|||||||
private double score;
|
private double score;
|
||||||
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
||||||
private Map<String, Object> properties = new HashMap<>();
|
private Map<String, Object> properties = new HashMap<>();
|
||||||
|
private EntityInfo entityInfo;
|
||||||
private SqlInfo sqlInfo = new SqlInfo();
|
private SqlInfo sqlInfo = new SqlInfo();
|
||||||
private QueryType queryType = QueryType.ID;
|
private QueryType queryType = QueryType.ID;
|
||||||
private EntityInfo entityInfo;
|
|
||||||
private String textInfo;
|
|
||||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -68,11 +72,15 @@ public class SemanticParseInfo {
|
|||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getDataSetId() {
|
public Long getViewId() {
|
||||||
if (dataSet == null) {
|
if (view == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return dataSet.getDataSet();
|
return view.getView();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getModel() {
|
||||||
|
return view;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,157 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
public class SemanticSchema implements Serializable {
|
||||||
|
|
||||||
|
private List<ViewSchema> viewSchemaList;
|
||||||
|
|
||||||
|
public SemanticSchema(List<ViewSchema> viewSchemaList) {
|
||||||
|
this.viewSchemaList = viewSchemaList;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void add(ViewSchema schema) {
|
||||||
|
viewSchemaList.add(schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = getElementsById(elementID, getEntities());
|
||||||
|
break;
|
||||||
|
case VIEW:
|
||||||
|
element = getElementsById(elementID, getViews());
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = getElementsById(elementID, getMetrics());
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = getElementsById(elementID, getDimensions());
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = getElementsById(elementID, getDimensionValues());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Long, String> getViewIdToName() {
|
||||||
|
return viewSchemaList.stream()
|
||||||
|
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getDimensionValues() {
|
||||||
|
List<SchemaElement> dimensionValues = new ArrayList<>();
|
||||||
|
viewSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||||
|
return dimensionValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getDimensions() {
|
||||||
|
List<SchemaElement> dimensions = new ArrayList<>();
|
||||||
|
viewSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||||
|
return dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getDimensions(Long viewId) {
|
||||||
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
|
return getElementsByViewId(viewId, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getDimension(Long id) {
|
||||||
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
|
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||||
|
return dimension.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTags() {
|
||||||
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
|
viewSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||||
|
return tags;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTags(Long viewId) {
|
||||||
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
|
viewSchemaList.stream().filter(schemaElement ->
|
||||||
|
viewId.equals(schemaElement.getView().getView()))
|
||||||
|
.forEach(d -> tags.addAll(d.getTags()));
|
||||||
|
return tags;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getMetrics() {
|
||||||
|
List<SchemaElement> metrics = new ArrayList<>();
|
||||||
|
viewSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||||
|
return metrics;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getMetrics(Long viewId) {
|
||||||
|
List<SchemaElement> metrics = getMetrics();
|
||||||
|
return getElementsByViewId(viewId, metrics);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getEntities() {
|
||||||
|
List<SchemaElement> entities = new ArrayList<>();
|
||||||
|
viewSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||||
|
return entities;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getEntities(Long viewId) {
|
||||||
|
List<SchemaElement> entities = getEntities();
|
||||||
|
return getElementsByViewId(viewId, entities);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SchemaElement> getElementsByViewId(Long viewId, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> viewId.equals(schemaElement.getView()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> id.equals(schemaElement.getId()))
|
||||||
|
.findFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getView(Long viewId) {
|
||||||
|
List<SchemaElement> views = getViews();
|
||||||
|
return getElementsById(viewId, views).orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getViews() {
|
||||||
|
List<SchemaElement> views = new ArrayList<>();
|
||||||
|
viewSchemaList.stream().forEach(d -> views.add(d.getView()));
|
||||||
|
return views;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, String> getBizNameToName(Long viewId) {
|
||||||
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
|
allElements.addAll(getDimensions(viewId));
|
||||||
|
allElements.addAll(getMetrics(viewId));
|
||||||
|
return allElements.stream()
|
||||||
|
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Long, ViewSchema> getViewSchemaMap() {
|
||||||
|
if (CollectionUtils.isEmpty(viewSchemaList)) {
|
||||||
|
return new HashMap<>();
|
||||||
|
}
|
||||||
|
return viewSchemaList.stream().collect(Collectors.toMap(viewSchema
|
||||||
|
-> viewSchema.getView().getView(), viewSchema -> viewSchema));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,18 +1,24 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DataSetSchema {
|
public class ViewSchema {
|
||||||
private SchemaElement dataSet;
|
|
||||||
|
private SchemaElement view;
|
||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
private Set<SchemaElement> tags = new HashSet<>();
|
|
||||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||||
private Set<SchemaElement> terms = new HashSet<>();
|
private Set<SchemaElement> tags = new HashSet<>();
|
||||||
private SchemaElement entity = new SchemaElement();
|
private SchemaElement entity = new SchemaElement();
|
||||||
private QueryConfig queryConfig;
|
private QueryConfig queryConfig;
|
||||||
|
|
||||||
@@ -23,8 +29,8 @@ public class DataSetSchema {
|
|||||||
case ENTITY:
|
case ENTITY:
|
||||||
element = Optional.ofNullable(entity);
|
element = Optional.ofNullable(entity);
|
||||||
break;
|
break;
|
||||||
case DATASET:
|
case VIEW:
|
||||||
element = Optional.of(dataSet);
|
element = Optional.of(view);
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
@@ -38,8 +44,34 @@ public class DataSetSchema {
|
|||||||
case TAG:
|
case TAG:
|
||||||
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
|
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
break;
|
break;
|
||||||
case TERM:
|
default:
|
||||||
element = terms.stream().filter(e -> e.getId() == elementID).findFirst();
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElement(SchemaElementType elementType, String name) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = Optional.ofNullable(entity);
|
||||||
|
break;
|
||||||
|
case VIEW:
|
||||||
|
element = Optional.of(view);
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = dimensions.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = dimensionValues.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -16,6 +16,16 @@ public class ChatConfigBaseReq {
|
|||||||
|
|
||||||
private Long modelId;
|
private Long modelId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* the chatDetailConfig about the model
|
||||||
|
*/
|
||||||
|
private ChatDetailConfigReq chatDetailConfig;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* the chatAggConfig about the model
|
||||||
|
*/
|
||||||
|
private ChatAggConfigReq chatAggConfig;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* the recommended questions about the model
|
* the recommended questions about the model
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@Builder
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
public class ChatExecuteReq {
|
|
||||||
private User user;
|
|
||||||
private Long queryId;
|
|
||||||
private Integer chatId;
|
|
||||||
private int parseId;
|
|
||||||
private String queryText;
|
|
||||||
private boolean saveAnswer;
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class ChatQueryDataReq {
|
|
||||||
private User user;
|
|
||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
|
||||||
private Set<QueryFilter> dimensionFilters = new HashSet<>();
|
|
||||||
private Set<QueryFilter> metricFilters = new HashSet<>();
|
|
||||||
private DateConf dateInfo;
|
|
||||||
private Long queryId;
|
|
||||||
private Integer parseId;
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -10,9 +10,11 @@ import lombok.Data;
|
|||||||
@Data
|
@Data
|
||||||
public class ExecuteQueryReq {
|
public class ExecuteQueryReq {
|
||||||
private User user;
|
private User user;
|
||||||
private Long queryId;
|
private Integer agentId;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
private Long queryId;
|
||||||
|
private Integer parseId;
|
||||||
private SemanticParseInfo parseInfo;
|
private SemanticParseInfo parseInfo;
|
||||||
private boolean saveAnswer;
|
private boolean saveAnswer;
|
||||||
}
|
}
|
||||||
@@ -13,7 +13,7 @@ public class PluginQueryReq {
|
|||||||
|
|
||||||
private String type;
|
private String type;
|
||||||
|
|
||||||
private String dataSet;
|
private String view;
|
||||||
|
|
||||||
private String pattern;
|
private String pattern;
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryDataReq {
|
public class QueryDataReq {
|
||||||
@@ -19,5 +17,5 @@ public class QueryDataReq {
|
|||||||
private Set<QueryFilter> metricFilters = new HashSet<>();
|
private Set<QueryFilter> metricFilters = new HashSet<>();
|
||||||
private DateConf dateInfo;
|
private DateConf dateInfo;
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
private SemanticParseInfo parseInfo;
|
private Integer parseId;
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -1,19 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ChatParseReq {
|
public class QueryReq {
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private Integer agentId;
|
private Long modelId;
|
||||||
private Integer topN = 10;
|
|
||||||
private User user;
|
private User user;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private Integer agentId;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -18,7 +18,7 @@ public class SimilarQueryReq {
|
|||||||
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
|
||||||
private Long dataSetId;
|
private Long viewId;
|
||||||
|
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class AggregateInfo {
|
public class AggregateInfo {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -8,7 +8,7 @@ import java.util.List;
|
|||||||
@Data
|
@Data
|
||||||
public class EntityInfo {
|
public class EntityInfo {
|
||||||
|
|
||||||
private DataSetInfo dataSetInfo = new DataSetInfo();
|
private ViewInfo viewInfo = new ViewInfo();
|
||||||
private List<DataInfo> dimensions = new ArrayList<>();
|
private List<DataInfo> dimensions = new ArrayList<>();
|
||||||
private List<DataInfo> metrics = new ArrayList<>();
|
private List<DataInfo> metrics = new ArrayList<>();
|
||||||
private String entityId;
|
private String entityId;
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MetricInfo {
|
public class MetricInfo {
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import lombok.Data;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class ParseResp {
|
||||||
|
private Integer chatId;
|
||||||
|
private String queryText;
|
||||||
|
private Long queryId;
|
||||||
|
private ParseState state;
|
||||||
|
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||||
|
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
|
||||||
|
|
||||||
|
public enum ParseState {
|
||||||
|
COMPLETED,
|
||||||
|
PENDING,
|
||||||
|
FAILED
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ParseTimeCostResp {
|
public class ParseTimeCostDO {
|
||||||
|
|
||||||
private long parseStartTime;
|
private long parseStartTime;
|
||||||
private long parseTime;
|
private long parseTime;
|
||||||
private long sqlTime;
|
private long sqlTime;
|
||||||
|
|
||||||
public ParseTimeCostResp() {
|
public ParseTimeCostDO() {
|
||||||
this.parseStartTime = System.currentTimeMillis();
|
this.parseStartTime = System.currentTimeMillis();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,12 +1,10 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryResp {
|
public class QueryResp {
|
||||||
|
|
||||||
@@ -20,4 +18,5 @@ public class QueryResp {
|
|||||||
private List<SemanticParseInfo> parseInfos;
|
private List<SemanticParseInfo> parseInfos;
|
||||||
private List<SimilarQueryRecallResp> similarQueries;
|
private List<SimilarQueryRecallResp> similarQueries;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
import com.tencent.supersonic.headless.api.pojo.AggregateInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryResult {
|
public class QueryResult {
|
||||||
|
|
||||||
|
public EntityInfo entityInfo;
|
||||||
|
public AggregateInfo aggregateInfo;
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
private String queryMode;
|
private String queryMode;
|
||||||
private String querySql;
|
private String querySql;
|
||||||
@@ -22,9 +22,6 @@ public class QueryResult {
|
|||||||
private SemanticParseInfo chatContext;
|
private SemanticParseInfo chatContext;
|
||||||
private Object response;
|
private Object response;
|
||||||
private List<Map<String, Object>> queryResults;
|
private List<Map<String, Object>> queryResults;
|
||||||
private String textResult;
|
|
||||||
private Long queryTimeCost;
|
private Long queryTimeCost;
|
||||||
private EntityInfo entityInfo;
|
|
||||||
private List<SchemaElement> recommendedDimensions;
|
private List<SchemaElement> recommendedDimensions;
|
||||||
private AggregateInfo aggregateInfo;
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
public enum QueryState {
|
public enum QueryState {
|
||||||
SUCCESS,
|
SUCCESS,
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,13 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import java.util.Objects;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Setter
|
@Setter
|
||||||
@Getter
|
@Getter
|
||||||
@@ -10,6 +10,8 @@ public class SimilarQueryRecallResp {
|
|||||||
|
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
|
|
||||||
|
private Integer parseId;
|
||||||
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -6,7 +6,7 @@ import java.io.Serializable;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DataSetInfo extends DataInfo implements Serializable {
|
public class ViewInfo extends DataInfo implements Serializable {
|
||||||
|
|
||||||
private List<String> words;
|
private List<String> words;
|
||||||
private String primaryKey;
|
private String primaryKey;
|
||||||
110
chat/core/pom.xml
Normal file
110
chat/core/pom.xml
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<parent>
|
||||||
|
<artifactId>chat</artifactId>
|
||||||
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
|
<version>${revision}</version>
|
||||||
|
</parent>
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<artifactId>chat-core</artifactId>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<maven.compiler.source>8</maven.compiler.source>
|
||||||
|
<maven.compiler.target>8</maven.compiler.target>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework</groupId>
|
||||||
|
<artifactId>spring-context</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.testng</groupId>
|
||||||
|
<artifactId>testng</artifactId>
|
||||||
|
<version>${org.testng.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-compress</artifactId>
|
||||||
|
<version>${commons.compress.version}</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-test</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.alibaba</groupId>
|
||||||
|
<artifactId>druid</artifactId>
|
||||||
|
<version>${alibaba.druid.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>mysql</groupId>
|
||||||
|
<artifactId>mysql-connector-java</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.h2database</groupId>
|
||||||
|
<artifactId>h2</artifactId>
|
||||||
|
<version>${h2.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
|
<artifactId>headless-api</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
|
<artifactId>headless-core</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
|
<artifactId>chat-api</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.github.xkzhangsan</groupId>
|
||||||
|
<artifactId>xk-time</artifactId>
|
||||||
|
<version>${xk.time.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.mockito</groupId>
|
||||||
|
<artifactId>mockito-inline</artifactId>
|
||||||
|
<version>${mockito-inline.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
|
<artifactId>headless-server</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
</project>
|
||||||
@@ -1,10 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -31,8 +29,6 @@ public class Agent extends RecordInfo {
|
|||||||
private Integer status;
|
private Integer status;
|
||||||
private List<String> examples;
|
private List<String> examples;
|
||||||
private String agentConfig;
|
private String agentConfig;
|
||||||
private LLMConfig llmConfig;
|
|
||||||
private MultiTurnConfig multiTurnConfig;
|
|
||||||
|
|
||||||
public List<String> getTools(AgentToolType type) {
|
public List<String> getTools(AgentToolType type) {
|
||||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||||
@@ -69,33 +65,12 @@ public class Agent extends RecordInfo {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean containsLLMParserTool() {
|
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean containsRuleTool() {
|
|
||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean containsNL2SQLTool() {
|
|
||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|
|
||||||
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
|
||||||
}
|
|
||||||
|
|
||||||
public Set<Long> getDataSetIds() {
|
|
||||||
Set<Long> dataSetIds = getDataSetIds(null);
|
|
||||||
if (containsAllModel(dataSetIds)) {
|
|
||||||
return Sets.newHashSet();
|
|
||||||
}
|
|
||||||
return dataSetIds;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
|
|
||||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
|
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
|
||||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||||
.flatMap(Collection::stream)
|
.flatMap(Collection::stream)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
@@ -12,6 +12,6 @@ import java.util.List;
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class NL2SQLTool extends AgentTool {
|
public class NL2SQLTool extends AgentTool {
|
||||||
|
|
||||||
protected List<Long> dataSetIds;
|
protected List<Long> viewIds;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
|
|||||||
private List<String> queryTypes;
|
private List<String> queryTypes;
|
||||||
|
|
||||||
public boolean isContainsAllModel() {
|
public boolean isContainsAllModel() {
|
||||||
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
|
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
|
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
import java.io.FileNotFoundException;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Configuration
|
@Configuration
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@ToString
|
@ToString
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* when query an entity, return related dimension/metric info
|
* when query an entity, return related dimension/metric info
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class EntityInternalDetail {
|
public class EntityInternalDetail {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -9,21 +9,19 @@ import org.springframework.context.annotation.Configuration;
|
|||||||
@Data
|
@Data
|
||||||
public class LLMParserConfig {
|
public class LLMParserConfig {
|
||||||
|
|
||||||
@Value("${s2.parser.url:}")
|
|
||||||
|
@Value("${llm.parser.url:}")
|
||||||
private String url;
|
private String url;
|
||||||
|
|
||||||
@Value("${s2.query2sql.path:/query2sql}")
|
@Value("${query2sql.path:/query2sql}")
|
||||||
private String queryToSqlPath;
|
private String queryToSqlPath;
|
||||||
|
|
||||||
@Value("${s2.dimension.topn:10}")
|
@Value("${dimension.topn:10}")
|
||||||
private Integer dimensionTopN;
|
private Integer dimensionTopN;
|
||||||
|
|
||||||
@Value("${s2.metric.topn:10}")
|
@Value("${metric.topn:10}")
|
||||||
private Integer metricTopN;
|
private Integer metricTopN;
|
||||||
|
|
||||||
@Value("${s2.tag.topn:20}")
|
@Value("${all.model:false}")
|
||||||
private Integer tagTopN;
|
|
||||||
|
|
||||||
@Value("${s2.all.model:false}")
|
|
||||||
private Boolean allModel;
|
private Boolean allModel;
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
@Data
|
||||||
|
@Slf4j
|
||||||
|
public class OptimizationConfig {
|
||||||
|
|
||||||
|
@Value("${one.detection.size:8}")
|
||||||
|
private Integer oneDetectionSize;
|
||||||
|
|
||||||
|
@Value("${one.detection.max.size:20}")
|
||||||
|
private Integer oneDetectionMaxSize;
|
||||||
|
|
||||||
|
@Value("${metric.dimension.min.threshold:0.3}")
|
||||||
|
private Double metricDimensionMinThresholdConfig;
|
||||||
|
|
||||||
|
@Value("${metric.dimension.threshold:0.3}")
|
||||||
|
private Double metricDimensionThresholdConfig;
|
||||||
|
|
||||||
|
@Value("${dimension.value.threshold:0.5}")
|
||||||
|
private Double dimensionValueThresholdConfig;
|
||||||
|
|
||||||
|
@Value("${long.text.threshold:0.8}")
|
||||||
|
private Double longTextThreshold;
|
||||||
|
|
||||||
|
@Value("${short.text.threshold:0.5}")
|
||||||
|
private Double shortTextThreshold;
|
||||||
|
|
||||||
|
@Value("${query.text.length.threshold:10}")
|
||||||
|
private Integer queryTextLengthThreshold;
|
||||||
|
@Value("${embedding.mapper.word.min:4}")
|
||||||
|
private int embeddingMapperWordMin;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.word.max:5}")
|
||||||
|
private int embeddingMapperWordMax;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.batch:50}")
|
||||||
|
private int embeddingMapperBatch;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.number:5}")
|
||||||
|
private int embeddingMapperNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.distance.threshold:0.01}")
|
||||||
|
private Double embeddingMapperDistanceThreshold;
|
||||||
|
|
||||||
|
@Value("${s2SQL.linking.value.switch:true}")
|
||||||
|
private boolean useLinkingValueSwitch;
|
||||||
|
|
||||||
|
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
|
||||||
|
private SqlGenerationMode sqlGenerationMode;
|
||||||
|
|
||||||
|
@Value("${s2SQL.use.switch:true}")
|
||||||
|
private boolean useS2SqlSwitch;
|
||||||
|
|
||||||
|
@Value("${text2sql.example.num:15}")
|
||||||
|
private int text2sqlExampleNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.fewShots.num:10}")
|
||||||
|
private int text2sqlFewShotsNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.self.consistency.num:5}")
|
||||||
|
private int text2sqlSelfConsistencyNum;
|
||||||
|
|
||||||
|
@Value("${parse.show.count:3}")
|
||||||
|
private Integer parseShowCount;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private SysParameterService sysParameterService;
|
||||||
|
|
||||||
|
public Integer getOneDetectionSize() {
|
||||||
|
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getOneDetectionMaxSize() {
|
||||||
|
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getMetricDimensionMinThresholdConfig() {
|
||||||
|
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getMetricDimensionThresholdConfig() {
|
||||||
|
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getDimensionValueThresholdConfig() {
|
||||||
|
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getLongTextThreshold() {
|
||||||
|
return convertValue("long.text.threshold", Double.class, longTextThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getShortTextThreshold() {
|
||||||
|
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getQueryTextLengthThreshold() {
|
||||||
|
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isUseS2SqlSwitch() {
|
||||||
|
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperWordMin() {
|
||||||
|
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperWordMax() {
|
||||||
|
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperBatch() {
|
||||||
|
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperNumber() {
|
||||||
|
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperRoundNumber() {
|
||||||
|
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getEmbeddingMapperDistanceThreshold() {
|
||||||
|
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isUseLinkingValueSwitch() {
|
||||||
|
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SqlGenerationMode getSqlGenerationMode() {
|
||||||
|
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getParseShowCount() {
|
||||||
|
return convertValue("parse.show.count", Integer.class, parseShowCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||||
|
try {
|
||||||
|
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
||||||
|
if (StringUtils.isBlank(value)) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
if (targetType == Double.class) {
|
||||||
|
return targetType.cast(Double.parseDouble(value));
|
||||||
|
} else if (targetType == Integer.class) {
|
||||||
|
return targetType.cast(Integer.parseInt(value));
|
||||||
|
} else if (targetType == Boolean.class) {
|
||||||
|
return targetType.cast(Boolean.parseBoolean(value));
|
||||||
|
} else if (targetType == SqlGenerationMode.class) {
|
||||||
|
return targetType.cast(SqlGenerationMode.valueOf(value));
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("convertValue", e);
|
||||||
|
}
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
@@ -37,7 +37,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
doCorrect(queryContext, semanticParseInfo);
|
doCorrect(queryContext, semanticParseInfo);
|
||||||
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||||
}
|
}
|
||||||
@@ -45,7 +45,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||||
|
|
||||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long dataSetId) {
|
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
|
||||||
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
// support fieldName and field alias
|
// support fieldName and field alias
|
||||||
Map<String, String> result = dbAllFields.stream()
|
Map<String, String> result = dbAllFields.stream()
|
||||||
.filter(entry -> dataSetId.equals(entry.getDataSet()))
|
.filter(entry -> viewId.equals(entry.getView()))
|
||||||
.flatMap(schemaElement -> {
|
.flatMap(schemaElement -> {
|
||||||
Set<String> elements = new HashSet<>();
|
Set<String> elements = new HashSet<>();
|
||||||
elements.add(schemaElement.getName());
|
elements.add(schemaElement.getName());
|
||||||
@@ -82,7 +82,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
//decide whether add order by expression field to select
|
//decide whether add order by expression field to select
|
||||||
Environment environment = ContextUtils.getBean(Environment.class);
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||||
}
|
}
|
||||||
@@ -109,8 +109,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
|
Long viewId = semanticParseInfo.getView().getView();
|
||||||
List<SchemaElement> metrics = getMetricElements(queryContext, dataSetId);
|
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
|
||||||
|
|
||||||
Map<String, String> metricToAggregate = metrics.stream()
|
Map<String, String> metricToAggregate = metrics.stream()
|
||||||
.map(schemaElement -> {
|
.map(schemaElement -> {
|
||||||
@@ -135,24 +135,9 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long dataSetId) {
|
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
return semanticSchema.getMetrics(dataSetId);
|
return semanticSchema.getMetrics(viewId);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
|
|
||||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
|
||||||
.flatMap(
|
|
||||||
schemaElement -> {
|
|
||||||
Set<String> elements = new HashSet<>();
|
|
||||||
elements.add(schemaElement.getName());
|
|
||||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
|
||||||
elements.addAll(schemaElement.getAlias());
|
|
||||||
}
|
|
||||||
return elements.stream();
|
|
||||||
}
|
|
||||||
).collect(Collectors.toSet());
|
|
||||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
|
||||||
return dimensions;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@@ -1,18 +1,25 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.core.env.Environment;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -32,7 +39,23 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
|
ViewService viewService = ContextUtils.getBean(ViewService.class);
|
||||||
|
ModelService modelService = ContextUtils.getBean(ModelService.class);
|
||||||
|
ViewResp viewResp = viewService.getView(viewId);
|
||||||
|
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
MetaFilter metaFilter = new MetaFilter();
|
||||||
|
metaFilter.setIds(modelIds);
|
||||||
|
List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
|
||||||
|
for (ModelResp modelResp : modelRespList) {
|
||||||
|
List<Dim> dimList = modelResp.getModelDetail().getDimensions();
|
||||||
|
for (Dim dim : dimList) {
|
||||||
|
if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
//add dimension group by
|
//add dimension group by
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
@@ -43,7 +66,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
//add alias field name
|
//add alias field name
|
||||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||||
return false;
|
return false;
|
||||||
@@ -56,22 +79,33 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
Environment environment = ContextUtils.getBean(Environment.class);
|
|
||||||
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
|
|
||||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && !Boolean.parseBoolean(correctorAdditionalInfo)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
|
||||||
|
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||||
|
.flatMap(
|
||||||
|
schemaElement -> {
|
||||||
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream();
|
||||||
|
}
|
||||||
|
).collect(Collectors.toSet());
|
||||||
|
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
return dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
//add dimension group by
|
//add dimension group by
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
//add alias field name
|
//add alias field name
|
||||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
Set<String> groupByFields = selectFields.stream()
|
Set<String> groupByFields = selectFields.stream()
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -31,7 +31,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
//decide whether add having expression field to select
|
//decide whether add having expression field to select
|
||||||
Environment environment = ContextUtils.getBean(Environment.class);
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
addHavingToSelect(semanticParseInfo);
|
addHavingToSelect(semanticParseInfo);
|
||||||
}
|
}
|
||||||
@@ -39,11 +39,11 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
|
Long viewId = semanticParseInfo.getView().getView();
|
||||||
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(metrics)) {
|
if (CollectionUtils.isEmpty(metrics)) {
|
||||||
@@ -1,26 +1,17 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.ParseResult;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -61,7 +52,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId());
|
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
@@ -114,35 +105,4 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
|
||||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
|
||||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
|
||||||
if (CollectionUtils.isEmpty(whereExpressionList)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
|
||||||
Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(linkingValues)) {
|
|
||||||
linkingValues = new ArrayList<>();
|
|
||||||
}
|
|
||||||
Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
|
|
||||||
Set<String> removeFieldNames = whereExpressionList.stream()
|
|
||||||
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
|
|
||||||
.filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
|
|
||||||
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
|
|
||||||
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
|
|
||||||
.filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
|
|
||||||
.filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
|
|
||||||
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
|
|
||||||
|
|
||||||
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,15 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import java.util.List;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform SQL corrections on the "Select" section in S2SQL.
|
* Perform SQL corrections on the "Select" section in S2SQL.
|
||||||
*/
|
*/
|
||||||
@@ -28,7 +25,5 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
||||||
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A semantic corrector checks validity of extracted semantic information and
|
* A semantic corrector checks validity of extracted semantic information and
|
||||||
@@ -1,25 +1,16 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
|
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
import java.util.Objects;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform SQL corrections on the time in S2SQL.
|
* Perform SQL corrections on the time in S2SQL.
|
||||||
@@ -30,39 +21,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
|||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
addDateIfNotExist(queryContext, semanticParseInfo);
|
|
||||||
|
|
||||||
parserDateDiffFunction(semanticParseInfo);
|
parserDateDiffFunction(semanticParseInfo);
|
||||||
|
|
||||||
addLowerBoundDate(semanticParseInfo);
|
addLowerBoundDate(semanticParseInfo);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
|
||||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
|
||||||
|
|
||||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
|
|
||||||
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
|
|
||||||
|
|
||||||
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
|
||||||
&& StringUtils.isNotBlank(startEndDate.getRight())) {
|
|
||||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
|
||||||
String dateChName = TimeDimensionEnum.DAY.getChName();
|
|
||||||
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
|
|
||||||
startEndDate.getLeft(), dateChName, startEndDate.getRight());
|
|
||||||
try {
|
|
||||||
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
|
|
||||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
|
||||||
} catch (JSQLParserException e) {
|
|
||||||
log.error("parseCondExpression:{}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||||
@@ -1,21 +1,24 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
|
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -34,6 +37,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
addDateIfNotExist(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
addQueryFilter(queryContext, semanticParseInfo);
|
addQueryFilter(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||||
@@ -57,6 +62,29 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||||
|
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||||
|
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
|
||||||
|
semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
|
||||||
|
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
||||||
|
&& StringUtils.isNotBlank(startEndDate.getRight())) {
|
||||||
|
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
|
String dateChName = TimeDimensionEnum.DAY.getChName();
|
||||||
|
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
|
||||||
|
startEndDate.getLeft(), dateChName, startEndDate.getRight());
|
||||||
|
try {
|
||||||
|
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||||
|
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||||
|
} catch (JSQLParserException e) {
|
||||||
|
log.error("parseCondExpression:{}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
|
||||||
private String getQueryFilter(QueryFilters queryFilters) {
|
private String getQueryFilter(QueryFilters queryFilters) {
|
||||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
return null;
|
return null;
|
||||||
@@ -73,8 +101,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
|
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(dimensions)) {
|
if (CollectionUtils.isEmpty(dimensions)) {
|
||||||
return;
|
return;
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public abstract class BaseMapper implements SchemaMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void map(QueryContext queryContext) {
|
||||||
|
|
||||||
|
String simpleName = this.getClass().getSimpleName();
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
|
||||||
|
|
||||||
|
try {
|
||||||
|
doMap(queryContext);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("work error", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
long cost = System.currentTimeMillis() - startTime;
|
||||||
|
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract void doMap(QueryContext queryContext);
|
||||||
|
|
||||||
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||||
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
|
||||||
|
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||||
|
if (schemaElementMatches == null) {
|
||||||
|
schemaElementMatches = modelElementMatches.get(modelId);
|
||||||
|
}
|
||||||
|
//remove duplication
|
||||||
|
AtomicBoolean needAddNew = new AtomicBoolean(true);
|
||||||
|
schemaElementMatches.removeIf(
|
||||||
|
existElementMatch -> {
|
||||||
|
SchemaElement existElement = existElementMatch.getElement();
|
||||||
|
SchemaElement newElement = newElementMatch.getElement();
|
||||||
|
if (existElement.equals(newElement)) {
|
||||||
|
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
needAddNew.set(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (needAddNew.get()) {
|
||||||
|
schemaElementMatches.add(newElementMatch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
|
||||||
|
SemanticSchema semanticSchema) {
|
||||||
|
SchemaElement element = new SchemaElement();
|
||||||
|
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||||
|
if (Objects.isNull(viewSchema)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
|
||||||
|
if (Objects.isNull(elementDb)) {
|
||||||
|
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
BeanUtils.copyProperties(elementDb, element);
|
||||||
|
element.setAlias(getAlias(elementDb));
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getAlias(SchemaElement element) {
|
||||||
|
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
|
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
|
||||||
|
element.getName())) {
|
||||||
|
return element.getAlias().stream()
|
||||||
|
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,17 +1,8 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.config.MapperConfig;
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -22,35 +13,37 @@ import java.util.Objects;
|
|||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
protected MapperHelper mapperHelper;
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
@Autowired
|
|
||||||
protected MapperConfig mapperConfig;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectViewIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
|
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
|
||||||
|
|
||||||
List<T> detects = detect(queryContext, terms, detectDataSetIds);
|
List<T> detects = detect(queryContext, terms, detectViewIds);
|
||||||
Map<MatchText, List<T>> result = new HashMap<>();
|
Map<MatchText, List<T>> result = new HashMap<>();
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
Set<T> results = new HashSet<>();
|
Set<T> results = new HashSet<>();
|
||||||
@@ -65,17 +58,18 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
if (index <= text.length()) {
|
if (index <= text.length()) {
|
||||||
String detectSegment = text.substring(startIndex, index).trim();
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
detectSegments.add(detectSegment);
|
detectSegments.add(detectSegment);
|
||||||
detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset);
|
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
}
|
}
|
||||||
detectByBatch(queryContext, results, detectDataSetIds, detectSegments);
|
detectByBatch(queryContext, results, detectViewIds, detectSegments);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
|
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||||
@@ -110,9 +104,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
||||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||||
terms = filterByDataSetId(terms, dataSetIds);
|
terms = filterByViewId(terms, viewIds);
|
||||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, dataSetIds);
|
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
|
||||||
List<T> matches = new ArrayList<>();
|
List<T> matches = new ArrayList<>();
|
||||||
if (Objects.isNull(matchResult)) {
|
if (Objects.isNull(matchResult)) {
|
||||||
return matches;
|
return matches;
|
||||||
@@ -127,17 +121,17 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
return matches;
|
return matches;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
|
public List<S2Term> filterByViewId(List<S2Term> terms, Set<Long> viewIds) {
|
||||||
logTerms(terms);
|
logTerms(terms);
|
||||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
if (CollectionUtils.isNotEmpty(viewIds)) {
|
||||||
terms = terms.stream().filter(term -> {
|
terms = terms.stream().filter(term -> {
|
||||||
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
|
Long viewId = NatureHelper.getViewId(term.getNature().toString());
|
||||||
if (Objects.nonNull(dataSetId)) {
|
if (Objects.nonNull(viewId)) {
|
||||||
return dataSetIds.contains(dataSetId);
|
return viewIds.contains(viewId);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
log.info("terms filter by dataSetId:{}", dataSetIds);
|
log.info("terms filter by viewId:{}", viewIds);
|
||||||
logTerms(terms);
|
logTerms(terms);
|
||||||
}
|
}
|
||||||
return terms;
|
return terms;
|
||||||
@@ -156,12 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
|
|
||||||
public abstract String getMapKey(T a);
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
|
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
|
||||||
String detectSegment, int offset);
|
String detectSegment, int offset);
|
||||||
|
|
||||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
|
||||||
double decreaseAmount = (threshold - minThreshold) / 4;
|
|
||||||
double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount;
|
|
||||||
return divideThreshold >= minThreshold ? divideThreshold : minThreshold;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -20,9 +21,6 @@ import java.util.Map.Entry;
|
|||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||||
* It currently supports fuzzy matching against names and aliases.
|
* It currently supports fuzzy matching against names and aliases.
|
||||||
@@ -31,13 +29,17 @@ import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NA
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
|
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
private List<SchemaElement> allElements;
|
private List<SchemaElement> allElements;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectViewIds) {
|
||||||
this.allElements = getSchemaElements(queryContext);
|
this.allElements = getSchemaElements(queryContext);
|
||||||
return super.match(queryContext, terms, detectDataSetIds);
|
return super.match(queryContext, terms, detectViewIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -52,8 +54,8 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
|
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||||
String detectSegment, int offset) {
|
String detectSegment, int offset) {
|
||||||
if (StringUtils.isBlank(detectSegment)) {
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -68,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Set<SchemaElement> schemaElements = entry.getValue();
|
Set<SchemaElement> schemaElements = entry.getValue();
|
||||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
if (!CollectionUtils.isEmpty(detectViewIds)) {
|
||||||
schemaElements = schemaElements.stream()
|
schemaElements = schemaElements.stream()
|
||||||
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
|
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
for (SchemaElement schemaElement : schemaElements) {
|
for (SchemaElement schemaElement : schemaElements) {
|
||||||
@@ -91,19 +93,22 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Double getThreshold(QueryContext queryContext) {
|
private Double getThreshold(QueryContext queryContext) {
|
||||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
|
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
|
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||||
|
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
|
||||||
|
|
||||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||||
|
|
||||||
if (!existElement) {
|
if (!existElement) {
|
||||||
threshold = threshold / 2;
|
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||||
log.info("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
|
||||||
modelElementMatches, threshold);
|
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||||
|
: metricDimensionMinThresholdConfig;
|
||||||
|
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||||
|
modelElementMatches, metricDimensionThresholdConfig);
|
||||||
}
|
}
|
||||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
return metricDimensionThresholdConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||||
@@ -1,18 +1,17 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
@@ -26,7 +25,8 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
//1. query from embedding by queryText
|
//1. query from embedding by queryText
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = queryContext.getQueryText();
|
||||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
|
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||||
|
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||||
|
|
||||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||||
@@ -36,12 +36,12 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
//2. build SchemaElementMatch by info
|
//2. build SchemaElementMatch by info
|
||||||
for (EmbeddingResult matchResult : matchResults) {
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||||
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
|
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||||
if (Objects.isNull(dataSetId)) {
|
if (Objects.isNull(viewId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||||
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
|
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
|
||||||
queryContext.getSemanticSchema());
|
queryContext.getSemanticSchema());
|
||||||
if (schemaElement == null) {
|
if (schemaElement == null) {
|
||||||
continue;
|
continue;
|
||||||
@@ -54,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
.detectWord(matchResult.getDetectWord())
|
.detectWord(matchResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
//3. add to mapInfo
|
//3. add to mapInfo
|
||||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,20 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -15,20 +22,6 @@ import org.springframework.beans.BeanUtils;
|
|||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_BATCH;
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MAX;
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MIN;
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_NUMBER;
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
|
|
||||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* EmbeddingMatchStrategy uses vector database to perform
|
* EmbeddingMatchStrategy uses vector database to perform
|
||||||
* similarity search against the embeddings of schema elements.
|
* similarity search against the embeddings of schema elements.
|
||||||
@@ -37,6 +30,9 @@ import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private MetaEmbeddingService metaEmbeddingService;
|
private MetaEmbeddingService metaEmbeddingService;
|
||||||
|
|
||||||
@@ -52,47 +48,39 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults,
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
String detectSegment, int offset) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results,
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MIN));
|
|
||||||
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MAX));
|
|
||||||
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_BATCH));
|
|
||||||
|
|
||||||
List<String> queryTextsList = detectSegments.stream()
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
.map(detectSegment -> detectSegment.trim())
|
.map(detectSegment -> detectSegment.trim())
|
||||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||||
&& detectSegment.length() >= embedddingMapperMin
|
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
|
||||||
&& detectSegment.length() <= embedddingMapperMax)
|
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||||
embeddingMapperBatch);
|
optimizationConfig.getEmbeddingMapperBatch());
|
||||||
|
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
|
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||||
List<String> queryTextsSub, QueryContext queryContext) {
|
List<String> queryTextsSub) {
|
||||||
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||||
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
|
||||||
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum());
|
|
||||||
|
|
||||||
// step1. build query params
|
// step1. build query params
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||||
|
|
||||||
// step2. retrieveQuery by detectSegment
|
// step2. retrieveQuery by detectSegment
|
||||||
int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
|
||||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
return;
|
return;
|
||||||
@@ -102,12 +90,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
.map(retrieveQueryResult -> {
|
.map(retrieveQueryResult -> {
|
||||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||||
retrievals.removeIf(retrieval -> {
|
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
|
||||||
return retrieval.getDistance() > 1 - threshold;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
return retrieveQueryResult;
|
return retrieveQueryResult;
|
||||||
})
|
})
|
||||||
@@ -126,8 +109,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
// step4. select mapResul in one round
|
// step4. select mapResul in one round
|
||||||
int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
|
||||||
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
|
||||||
List<EmbeddingResult> oneRoundResults = collect.stream()
|
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||||
.limit(roundNumber)
|
.limit(roundNumber)
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -23,18 +24,18 @@ public class EntityMapper extends BaseMapper {
|
|||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
|
||||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
|
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
|
||||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElement entity = getEntity(dataSetId, queryContext);
|
SchemaElement entity = getEntity(viewId, queryContext);
|
||||||
if (entity == null || entity.getId() == null) {
|
if (entity == null || entity.getId() == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||||
.filter(schemaElementMatch -> SchemaElementType.VALUE.equals(
|
.filter(schemaElementMatch ->
|
||||||
schemaElementMatch.getElement().getType()))
|
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||||
@@ -64,9 +65,9 @@ public class EntityMapper extends BaseMapper {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) {
|
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||||
return modelSchema.getEntity();
|
return modelSchema.getEntity();
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
|
||||||
|
* match schema elements. It currently supports prefix and suffix matching
|
||||||
|
* against names, values and aliases.
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private KnowledgeService knowledgeService;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
|
Set<Long> detectViewIds) {
|
||||||
|
String text = queryContext.getQueryText();
|
||||||
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
|
||||||
|
|
||||||
|
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
|
||||||
|
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||||
|
|
||||||
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
|
String detectSegment, int offset) {
|
||||||
|
// step1. pre search
|
||||||
|
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||||
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||||
|
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
// step2. suffix search
|
||||||
|
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||||
|
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// step3. merge pre/suffix result
|
||||||
|
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||||
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
// step4. filter by similarity
|
||||||
|
hanlpMapResults = hanlpMapResults.stream()
|
||||||
|
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||||
|
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
||||||
|
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||||
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
log.info("after isSimilarity parseResults:{}", hanlpMapResults);
|
||||||
|
|
||||||
|
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
|
||||||
|
parseResult.setOffset(offset);
|
||||||
|
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
||||||
|
return parseResult;
|
||||||
|
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
// step5. take only one dimension or 10 metric/dimension value per rond.
|
||||||
|
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
|
||||||
|
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
.stream()
|
||||||
|
.limit(1)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
|
||||||
|
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||||
|
oneRoundResults = dimensionMetrics;
|
||||||
|
}
|
||||||
|
// step6. select mapResul in one round
|
||||||
|
selectResultInOneRound(existResults, oneRoundResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getMapKey(HanlpMapResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -33,7 +33,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = queryContext.getQueryText();
|
||||||
//1.hanlpDict Match
|
//1.hanlpDict Match
|
||||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
|
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||||
|
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||||
|
|
||||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
||||||
@@ -47,7 +48,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
||||||
List<S2Term> terms) {
|
List<S2Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
if (CollectionUtils.isEmpty(mapResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -58,8 +59,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
|
|
||||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||||
for (String nature : hanlpMapResult.getNatures()) {
|
for (String nature : hanlpMapResult.getNatures()) {
|
||||||
Long dataSetId = NatureHelper.getDataSetId(nature);
|
Long viewId = NatureHelper.getViewId(nature);
|
||||||
if (Objects.isNull(dataSetId)) {
|
if (Objects.isNull(viewId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||||
@@ -67,11 +68,14 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Long elementID = NatureHelper.getElementID(nature);
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
SchemaElement element = getSchemaElement(dataSetId, elementType,
|
SchemaElement element = getSchemaElement(viewId, elementType,
|
||||||
elementID, queryContext.getSemanticSchema());
|
elementID, queryContext.getSemanticSchema());
|
||||||
if (element == null) {
|
if (element == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||||
|
element.setName(hanlpMapResult.getName());
|
||||||
|
}
|
||||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
@@ -81,7 +85,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.detectWord(hanlpMapResult.getDetectWord())
|
.detectWord(hanlpMapResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,16 +102,16 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.element(schemaElement)
|
.element(schemaElement)
|
||||||
.word(schemaElement.getName())
|
.word(schemaElement.getName())
|
||||||
.detectWord(match.getDetectWord())
|
.detectWord(match.getDetectWord())
|
||||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
.frequency(10000L)
|
||||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||||
.build();
|
.build();
|
||||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
|
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
|
||||||
if (CollectionUtils.isEmpty(elements)) {
|
if (CollectionUtils.isEmpty(elements)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
@@ -1,16 +1,21 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||||
|
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||||
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -18,6 +23,9 @@ import java.util.stream.Collectors;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class MapperHelper {
|
public class MapperHelper {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
|
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
|
||||||
Integer subRegLength = regOffsetToLength.get(index);
|
Integer subRegLength = regOffsetToLength.get(index);
|
||||||
if (Objects.nonNull(subRegLength)) {
|
if (Objects.nonNull(subRegLength)) {
|
||||||
@@ -40,6 +48,13 @@ public class MapperHelper {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public double getThresholdMatch(List<String> natures) {
|
||||||
|
if (existDimensionValues(natures)) {
|
||||||
|
return optimizationConfig.getDimensionValueThresholdConfig();
|
||||||
|
}
|
||||||
|
return optimizationConfig.getMetricDimensionThresholdConfig();
|
||||||
|
}
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* exist dimension values
|
* exist dimension values
|
||||||
* @param natures
|
* @param natures
|
||||||
@@ -47,16 +62,7 @@ public class MapperHelper {
|
|||||||
*/
|
*/
|
||||||
public boolean existDimensionValues(List<String> natures) {
|
public boolean existDimensionValues(List<String> natures) {
|
||||||
for (String nature : natures) {
|
for (String nature : natures) {
|
||||||
if (NatureHelper.isDimensionValueDataSetId(nature)) {
|
if (NatureHelper.isDimensionValueViewId(nature)) {
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean existTerms(List<String> natures) {
|
|
||||||
for (String nature : natures) {
|
|
||||||
if (NatureHelper.isTermNature(nature)) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,4 +81,34 @@ public class MapperHelper {
|
|||||||
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
|
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
|
||||||
detectSegment.length());
|
detectSegment.length());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Set<Long> getViewIds(Long viewId, Agent agent) {
|
||||||
|
|
||||||
|
Set<Long> detectViewIds = new HashSet<>();
|
||||||
|
if (Objects.nonNull(agent)) {
|
||||||
|
detectViewIds = agent.getViewIds(null);
|
||||||
|
}
|
||||||
|
//contains all
|
||||||
|
if (Agent.containsAllModel(detectViewIds)) {
|
||||||
|
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||||
|
Set<Long> result = new HashSet<>();
|
||||||
|
result.add(viewId);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
return new HashSet<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Objects.nonNull(detectViewIds)) {
|
||||||
|
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
|
||||||
|
if (detectViewIds.contains(viewId)) {
|
||||||
|
Set<Long> result = new HashSet<>();
|
||||||
|
result.add(viewId);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return detectViewIds;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -14,6 +13,6 @@ import java.util.Set;
|
|||||||
*/
|
*/
|
||||||
public interface MatchStrategy<T> {
|
public interface MatchStrategy<T> {
|
||||||
|
|
||||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ToString
|
@ToString
|
||||||
@Builder
|
@Builder
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import java.io.Serializable;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ToString
|
@ToString
|
||||||
public class ModelWithSemanticType implements Serializable {
|
public class ModelWithSemanticType implements Serializable {
|
||||||
@@ -1,59 +1,56 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class QueryFilterMapper extends BaseMapper {
|
public class QueryFilterMapper implements SchemaMapper {
|
||||||
|
|
||||||
private double similarity = 1.0;
|
private double similarity = 1.0;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void map(QueryContext queryContext) {
|
||||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
Long viewId = queryContext.getViewId();
|
||||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
if (viewId == null || viewId <= 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
|
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
|
||||||
for (Long dataSetId : dataSetIds) {
|
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(dataSetId);
|
if (schemaElementMatches == null) {
|
||||||
if (schemaElementMatches == null) {
|
schemaElementMatches = Lists.newArrayList();
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||||
schemaMapInfo.setMatchedElements(dataSetId, schemaElementMatches);
|
|
||||||
}
|
|
||||||
addValueSchemaElementMatch(dataSetId, queryContext, schemaElementMatches);
|
|
||||||
}
|
}
|
||||||
|
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
|
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDataSetElementMatches().entrySet()) {
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||||
if (!viewIds.contains(entry.getKey())) {
|
if (!entry.getKey().equals(modelId)) {
|
||||||
entry.getValue().clear();
|
entry.getValue().clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext,
|
private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
|
||||||
List<SchemaElementMatch> candidateElementMatches) {
|
List<SchemaElementMatch> candidateElementMatches) {
|
||||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
return;
|
return candidateElementMatches;
|
||||||
}
|
}
|
||||||
for (QueryFilter filter : queryFilters.getFilters()) {
|
for (QueryFilter filter : queryFilters.getFilters()) {
|
||||||
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
||||||
@@ -64,7 +61,7 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
.name(String.valueOf(filter.getValue()))
|
.name(String.valueOf(filter.getValue()))
|
||||||
.type(SchemaElementType.VALUE)
|
.type(SchemaElementType.VALUE)
|
||||||
.bizName(filter.getBizName())
|
.bizName(filter.getBizName())
|
||||||
.dataSet(dataSetId)
|
.view(queryContext.getViewId())
|
||||||
.build();
|
.build();
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
@@ -75,7 +72,7 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
.build();
|
.build();
|
||||||
candidateElementMatches.add(schemaElementMatch);
|
candidateElementMatches.add(schemaElementMatch);
|
||||||
}
|
}
|
||||||
queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
|
return candidateElementMatches;
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
||||||
@@ -1,23 +1,22 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.core.knowledge.SearchService;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* SearchMatchStrategy encapsulates a concrete matching algorithm
|
* SearchMatchStrategy encapsulates a concrete matching algorithm
|
||||||
@@ -29,11 +28,11 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
private static final int SEARCH_SIZE = 3;
|
private static final int SEARCH_SIZE = 3;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private KnowledgeBaseService knowledgeBaseService;
|
private KnowledgeService knowledgeService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectViewIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||||
|
|
||||||
@@ -57,10 +56,10 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
String detectSegment = text.substring(detectIndex);
|
String detectSegment = text.substring(detectIndex);
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||||
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
SearchService.SEARCH_SIZE, detectViewIds);
|
||||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
|
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
detectSegment, SEARCH_SIZE, detectViewIds);
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
// remove entity name where search
|
// remove entity name where search
|
||||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||||
@@ -94,7 +93,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
String detectSegment, int offset) {
|
String detectSegment, int offset) {
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLMProxy based on langchain4j Java version.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class JavaLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkip(QueryContext queryContext) {
|
||||||
|
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
|
if (Objects.isNull(chatLanguageModel)) {
|
||||||
|
log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||||
|
|
||||||
|
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||||
|
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||||
|
String modelName = llmReq.getSchema().getViewName();
|
||||||
|
LLMResp result = sqlGeneration.generation(llmReq, viewId);
|
||||||
|
result.setQuery(llmReq.getQueryText());
|
||||||
|
result.setModelName(modelName);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||||
|
|
||||||
|
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
|
||||||
|
|
||||||
|
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
|
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||||
|
functionReq.getPluginConfigs());
|
||||||
|
keyPipelineLog.info("functionCallPrompt:{}", functionCallPrompt);
|
||||||
|
String response = chatLanguageModel.generate(functionCallPrompt);
|
||||||
|
keyPipelineLog.info("functionCall response:{}", response);
|
||||||
|
return OutputFormat.functionCallParse(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLMProxy encapsulates functions performed by LLMs so that multiple
|
||||||
|
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
|
||||||
|
* could be used.
|
||||||
|
*/
|
||||||
|
public interface LLMProxy {
|
||||||
|
|
||||||
|
boolean isSkip(QueryContext queryContext);
|
||||||
|
|
||||||
|
LLMResp query2sql(LLMReq llmReq, Long viewId);
|
||||||
|
|
||||||
|
FunctionResp requestFunction(FunctionReq functionReq);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSON;
|
||||||
|
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* PythonLLMProxy sends requests to LangChain-based python service.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class PythonLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkip(QueryContext queryContext) {
|
||||||
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
|
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||||
|
log.warn("llmParserUrl is empty, skip :{}", PythonLLMProxy.class.getName());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
|
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
|
try {
|
||||||
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
|
|
||||||
|
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
||||||
|
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||||
|
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||||
|
LLMResp.class);
|
||||||
|
|
||||||
|
LLMResp llmResp = responseEntity.getBody();
|
||||||
|
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||||
|
System.currentTimeMillis() - startTime, url, entity, llmResp);
|
||||||
|
keyPipelineLog.info("LLMResp:{}", llmResp);
|
||||||
|
|
||||||
|
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||||
|
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
|
||||||
|
}
|
||||||
|
return llmResp;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("requestLLM error", e);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||||
|
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||||
|
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
|
||||||
|
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||||
|
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||||
|
try {
|
||||||
|
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
||||||
|
keyPipelineLog.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
||||||
|
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||||
|
FunctionResp.class);
|
||||||
|
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
|
||||||
|
System.currentTimeMillis() - startTime);
|
||||||
|
keyPipelineLog.info("response:{}", responseEntity.getBody());
|
||||||
|
return responseEntity.getBody();
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("requestFunction error", e);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||||
|
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
|
|
||||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -50,54 +50,38 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
//1. entity queryType
|
//1. entity queryType
|
||||||
Long dataSetId = parseInfo.getDataSetId();
|
Long viewId = parseInfo.getViewId();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
//If all the fields in the SELECT statement are of tag type.
|
||||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||||
|
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream()
|
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||||
.anyMatch(whereFilterByTimeFields::contains)) {
|
|
||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
}
|
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
|
||||||
selectFields.addAll(whereFields);
|
|
||||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
|
|
||||||
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
|
||||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
//If all the fields in the SELECT/WHERE statement are of tag type.
|
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||||
if (CollectionUtils.isNotEmpty(tags)
|
return QueryType.TAG;
|
||||||
&& tags.containsAll(selectWhereFilterByTimeFields)) {
|
|
||||||
return QueryType.DETAIL;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//2. metric queryType
|
//2. metric queryType
|
||||||
if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
|
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
return QueryType.METRIC;
|
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||||
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
|
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||||
|
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||||
|
if (containMetric) {
|
||||||
|
return QueryType.METRIC;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
|
||||||
List<String> selectAndWhereFilterByTimeFields = whereFields
|
|
||||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
return selectAndWhereFilterByTimeFields;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
|
||||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
|
||||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
|
||||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
|
||||||
return selectFields.stream().anyMatch(metricNameSet::contains);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This checker can be used by semantic parsers to check if query intent
|
||||||
|
* has already been satisfied by current candidate queries. If so, current
|
||||||
|
* parser could be skipped.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SatisfactionChecker {
|
||||||
|
|
||||||
|
// check all the parse info in candidate
|
||||||
|
public static boolean isSkip(QueryContext queryContext) {
|
||||||
|
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||||
|
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
|
||||||
|
int queryTextLength = queryText.replaceAll(" ", "").length();
|
||||||
|
double degree = semanticParseInfo.getScore() / queryTextLength;
|
||||||
|
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||||
|
if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
|
||||||
|
if (degree < optimizationConfig.getLongTextThreshold()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (degree < optimizationConfig.getShortTextThreshold()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
log.info("queryMode:{}, degree:{}, parse info:{}",
|
||||||
|
semanticParseInfo.getQueryMode(), degree, semanticParseInfo);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A semantic parser understands user queries and extracts semantic information.
|
||||||
|
* It could leverage either rule-based or LLM-based approach to identify query intent
|
||||||
|
* and extract related semantic items from the query.
|
||||||
|
*/
|
||||||
|
public interface SemanticParser {
|
||||||
|
|
||||||
|
void parse(QueryContext queryContext, ChatContext chatContext);
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.plugin;
|
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||||
|
|
||||||
public enum ParseMode {
|
public enum ParseMode {
|
||||||
|
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.google.common.collect.Sets;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||||
|
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||||
|
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||||
|
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||||
|
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||||
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* PluginParser defines the basic process and common methods for recalling plugins.
|
||||||
|
*/
|
||||||
|
public abstract class PluginParser implements SemanticParser {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||||
|
if (queryContext.getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
||||||
|
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!checkPreCondition(queryContext)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PluginRecallResult pluginRecallResult = recallPlugin(queryContext);
|
||||||
|
if (pluginRecallResult == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
buildQuery(queryContext, pluginRecallResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract boolean checkPreCondition(QueryContext queryContext);
|
||||||
|
|
||||||
|
public abstract PluginRecallResult recallPlugin(QueryContext queryContext);
|
||||||
|
|
||||||
|
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||||
|
Plugin plugin = pluginRecallResult.getPlugin();
|
||||||
|
Set<Long> viewIds = pluginRecallResult.getViewIds();
|
||||||
|
if (plugin.isContainsAllModel()) {
|
||||||
|
viewIds = Sets.newHashSet(-1L);
|
||||||
|
}
|
||||||
|
for (Long viewId : viewIds) {
|
||||||
|
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||||
|
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
|
||||||
|
queryContext, pluginRecallResult.getDistance());
|
||||||
|
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||||
|
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||||
|
pluginQuery.setParseInfo(semanticParseInfo);
|
||||||
|
queryContext.getCandidateQueries().add(pluginQuery);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||||
|
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
|
||||||
|
QueryContext queryContext, double distance) {
|
||||||
|
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||||
|
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||||
|
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
|
||||||
|
viewId = plugin.getViewList().get(0);
|
||||||
|
}
|
||||||
|
if (schemaElementMatches == null) {
|
||||||
|
schemaElementMatches = Lists.newArrayList();
|
||||||
|
}
|
||||||
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
|
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||||
|
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
|
||||||
|
Map<String, Object> properties = new HashMap<>();
|
||||||
|
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||||
|
pluginParseResult.setPlugin(plugin);
|
||||||
|
pluginParseResult.setQueryFilters(queryFilters);
|
||||||
|
pluginParseResult.setDistance(distance);
|
||||||
|
pluginParseResult.setQueryText(queryContext.getQueryText());
|
||||||
|
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||||
|
properties.put("type", "plugin");
|
||||||
|
properties.put("name", plugin.getName());
|
||||||
|
semanticParseInfo.setProperties(properties);
|
||||||
|
semanticParseInfo.setScore(distance);
|
||||||
|
fillSemanticParseInfo(semanticParseInfo);
|
||||||
|
return semanticParseInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
|
||||||
|
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
|
||||||
|
if (CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||||
|
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||||
|
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||||
|
.forEach(schemaElementMatch -> {
|
||||||
|
QueryFilter queryFilter = new QueryFilter();
|
||||||
|
queryFilter.setValue(schemaElementMatch.getWord());
|
||||||
|
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||||
|
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||||
|
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||||
|
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||||
|
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.server.plugin.ParseMode;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.headless.core.chat.parser.llm.PythonLLMProxy;
|
|
||||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
@@ -28,42 +28,44 @@ import java.util.stream.Collectors;
|
|||||||
* EmbeddingRecallParser is an implementation of a recall plugin based on Embedding
|
* EmbeddingRecallParser is an implementation of a recall plugin based on Embedding
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
public class EmbeddingRecallParser extends PluginParser {
|
||||||
|
|
||||||
public boolean checkPreCondition(ChatParseContext chatParseContext) {
|
@Override
|
||||||
|
public boolean checkPreCondition(QueryContext queryContext) {
|
||||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
List<Plugin> plugins = getPluginList(chatParseContext);
|
List<Plugin> plugins = getPluginList(queryContext);
|
||||||
return !CollectionUtils.isEmpty(plugins);
|
return !CollectionUtils.isEmpty(plugins);
|
||||||
}
|
}
|
||||||
|
|
||||||
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
|
@Override
|
||||||
String text = chatParseContext.getQueryText();
|
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||||
|
String text = queryContext.getQueryText();
|
||||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
List<Plugin> plugins = getPluginList(chatParseContext);
|
List<Plugin> plugins = getPluginList(queryContext);
|
||||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||||
if (plugin == null) {
|
if (plugin == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
|
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||||
log.info("embedding plugin resolve: {}", pair);
|
log.info("embedding plugin resolve: {}", pair);
|
||||||
if (pair.getLeft()) {
|
if (pair.getLeft()) {
|
||||||
Set<Long> dataSetList = pair.getRight();
|
Set<Long> viewList = pair.getRight();
|
||||||
if (CollectionUtils.isEmpty(dataSetList)) {
|
if (CollectionUtils.isEmpty(viewList)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||||
double distance = embeddingRetrieval.getDistance();
|
double distance = embeddingRetrieval.getDistance();
|
||||||
double score = chatParseContext.getQueryText().length() * (1 - distance);
|
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||||
return PluginRecallResult.builder()
|
return PluginRecallResult.builder()
|
||||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
@Data
|
||||||
|
public class FunctionCallConfig {
|
||||||
|
@Value("${functionCall.url:}")
|
||||||
|
private String url;
|
||||||
|
|
||||||
|
@Value("${funtionCall.plugin.select.path:/plugin_selection}")
|
||||||
|
private String pluginSelectPath;
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user