mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Compare commits
168 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40ea6a9396 | ||
|
|
78d724ea83 | ||
|
|
eadbdc4e30 | ||
|
|
a909493414 | ||
|
|
aa86fc9275 | ||
|
|
e36060eae4 | ||
|
|
3893e897cb | ||
|
|
617cd87a48 | ||
|
|
042a610231 | ||
|
|
b555beae21 | ||
|
|
ba01cdb9bc | ||
|
|
e610dd8246 | ||
|
|
01bc4dcacf | ||
|
|
c224b81160 | ||
|
|
bfd0e040da | ||
|
|
f50a3157d5 | ||
|
|
61316e939c | ||
|
|
fab1bac50c | ||
|
|
d8043c356f | ||
|
|
e95a528219 | ||
|
|
16643e8d75 | ||
|
|
417a43dee8 | ||
|
|
4a22fdf452 | ||
|
|
16afbc6945 | ||
|
|
b8831317e9 | ||
|
|
fc5ff01eca | ||
|
|
d10801ef38 | ||
|
|
33240cc382 | ||
|
|
3317f1b7ec | ||
|
|
b85778babd | ||
|
|
699a33b1c1 | ||
|
|
fdb69547e6 | ||
|
|
39158d6877 | ||
|
|
329ad327b0 | ||
|
|
9600456bae | ||
|
|
74d0ec2b23 | ||
|
|
8a342eb32a | ||
|
|
e801c448be | ||
|
|
da5e7b9b75 | ||
|
|
75853a8e9e | ||
|
|
2546d1c0e1 | ||
|
|
0c4c6d83ef | ||
|
|
4d4922d269 | ||
|
|
1004f71ba4 | ||
|
|
c13a0e672c | ||
|
|
491c76368c | ||
|
|
2c1c443b3e | ||
|
|
f29b1854ba | ||
|
|
7f15bacca4 | ||
|
|
df975b231d | ||
|
|
24b442baef | ||
|
|
31f8c1df35 | ||
|
|
26aefceb04 | ||
|
|
954c67c947 | ||
|
|
fdfad515dd | ||
|
|
c398ac1a84 | ||
|
|
aae3d6b297 | ||
|
|
923c65b2f9 | ||
|
|
22775343f4 | ||
|
|
d9533c53ea | ||
|
|
841db25198 | ||
|
|
922201c181 | ||
|
|
48fb01f6bc | ||
|
|
9d6f96e6d4 | ||
|
|
42a6f61456 | ||
|
|
163e782f51 | ||
|
|
be158a1776 | ||
|
|
c12f5d23f0 | ||
|
|
7ec77c7d23 | ||
|
|
f04cc28f25 | ||
|
|
b28eb637c8 | ||
|
|
97c767a45b | ||
|
|
7afa42b4bc | ||
|
|
a375a922c2 | ||
|
|
add74b9589 | ||
|
|
566321895e | ||
|
|
f154c2a2b4 | ||
|
|
9c6bd7cf19 | ||
|
|
20c8456705 | ||
|
|
c154f476cb | ||
|
|
1ef1aa53a3 | ||
|
|
08c184c7b0 | ||
|
|
36edc0c1b4 | ||
|
|
026cf2056d | ||
|
|
fc82350af5 | ||
|
|
6b5d84a13f | ||
|
|
3ba9073a1b | ||
|
|
cbf38ed785 | ||
|
|
f017f41201 | ||
|
|
b40670b0e3 | ||
|
|
7af5afc3eb | ||
|
|
0abbd83f51 | ||
|
|
3e77fc3069 | ||
|
|
b019f4d9bb | ||
|
|
90f9da162e | ||
|
|
a06a1fa898 | ||
|
|
dfb8e3a427 | ||
|
|
d4eecc1bf8 | ||
|
|
71c491a80d | ||
|
|
93c3ce1631 | ||
|
|
c181ce6945 | ||
|
|
07e6924cfd | ||
|
|
b2beecb5b8 | ||
|
|
eb08667d90 | ||
|
|
74b89a9430 | ||
|
|
7707179faa | ||
|
|
af103f3aa3 | ||
|
|
7f65057a0f | ||
|
|
a6818fb6ff | ||
|
|
e7d654f150 | ||
|
|
e29ecec0c9 | ||
|
|
dcc1f26542 | ||
|
|
3436b36552 | ||
|
|
4322ae42ac | ||
|
|
e9c7237794 | ||
|
|
3a5349c916 | ||
|
|
1e93282c9f | ||
|
|
9c8039c499 | ||
|
|
61da52650c | ||
|
|
87a60eeba2 | ||
|
|
052e217c8c | ||
|
|
bbad302efd | ||
|
|
ed54d7bae3 | ||
|
|
0a6160272b | ||
|
|
ad8079e058 | ||
|
|
8eef11f342 | ||
|
|
b55b4c130e | ||
|
|
0408f0fe9a | ||
|
|
062f7340e5 | ||
|
|
72bd79fe73 | ||
|
|
602b9547b8 | ||
|
|
ade96c3adc | ||
|
|
023e84c420 | ||
|
|
0858c13365 | ||
|
|
7acb48da0e | ||
|
|
52fea5311d | ||
|
|
d72166944c | ||
|
|
e7f13572d7 | ||
|
|
af1c560cc4 | ||
|
|
49f0a4dc1d | ||
|
|
31f1fc315d | ||
|
|
afa8fd74a6 | ||
|
|
56b0f35250 | ||
|
|
af6c8cdbda | ||
|
|
9e69002d70 | ||
|
|
cd727663a5 | ||
|
|
d0289a3243 | ||
|
|
327bab015e | ||
|
|
f788249b1a | ||
|
|
6c8ebdfe1a | ||
|
|
b706c4efb4 | ||
|
|
cf2b4bfb5c | ||
|
|
40c86810bb | ||
|
|
5ab1cade0a | ||
|
|
e0955c0618 | ||
|
|
125598bd6e | ||
|
|
0cbfe473dd | ||
|
|
090abbceed | ||
|
|
0a1f446fb8 | ||
|
|
fa38e37be3 | ||
|
|
7b580b7c94 | ||
|
|
2631352c30 | ||
|
|
f7914ff6f4 | ||
|
|
ab077df36d | ||
|
|
150d67f903 | ||
|
|
187dcacbe7 | ||
|
|
a194822cdd | ||
|
|
6bbc0a2cb4 |
31
.github/workflows/ci.yml
vendored
Normal file
31
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
name: supersonic CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up JDK 8
|
||||||
|
uses: actions/setup-java@v2
|
||||||
|
with:
|
||||||
|
java-version: '8'
|
||||||
|
distribution: 'adopt'
|
||||||
|
- name: Cache Maven packages
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/.m2
|
||||||
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
|
restore-keys: ${{ runner.os }}-m2
|
||||||
|
- name: Build with Maven
|
||||||
|
run: mvn -B package --file pom.xml
|
||||||
|
- name: Test with Maven
|
||||||
|
run: mvn test
|
||||||
16
CHANGELOG.md
16
CHANGELOG.md
@@ -4,6 +4,22 @@
|
|||||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||||
compatibility issues with previous versions.
|
compatibility issues with previous versions.
|
||||||
|
|
||||||
|
## SuperSonic [0.8.6] - 2024-02-23
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- support view abstraction to Headless.
|
||||||
|
- add the Metric API to Headless and optimizing the Headless API.
|
||||||
|
- add integration tests to Headless.
|
||||||
|
- add TimeCorrector to Chat.
|
||||||
|
|
||||||
|
## SuperSonic [0.8.4] - 2024-01-19
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- support creating derived metrics.
|
||||||
|
- Support creating metrics using three methods: by measure, metric, and field expressions.
|
||||||
|
- added support for postgresql data source.
|
||||||
|
- code adjustment and abstract optimization for chat and headless.
|
||||||
|
|
||||||
## SuperSonic [0.8.2] - 2023-12-18
|
## SuperSonic [0.8.2] - 2023-12-18
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
[中文介绍](README_CN.md) | [文档中心](https://github.com/tencentmusic/supersonic/wiki)
|
[中文介绍](README_CN.md) | [文档中心](https://github.com/tencentmusic/supersonic/wiki)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
# SuperSonic (超音数)
|
# SuperSonic (超音数)
|
||||||
|
|
||||||
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ if [ $? -ne 0 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
#2. move package to build
|
#2. move package to build
|
||||||
cp $baseDir/../launchers/semantic/target/*.tar.gz ${buildDir}/supersonic-semantic.tar.gz
|
cp $baseDir/../launchers/headless/target/*.tar.gz ${buildDir}/supersonic-headless.tar.gz
|
||||||
cp $baseDir/../launchers/chat/target/*.tar.gz ${buildDir}/supersonic-chat.tar.gz
|
cp $baseDir/../launchers/chat/target/*.tar.gz ${buildDir}/supersonic-chat.tar.gz
|
||||||
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic-standalone.tar.gz
|
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic-standalone.tar.gz
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ fi
|
|||||||
cd $buildDir
|
cd $buildDir
|
||||||
tar xvf supersonic-webapp.tar.gz
|
tar xvf supersonic-webapp.tar.gz
|
||||||
mv supersonic-webapp webapp
|
mv supersonic-webapp webapp
|
||||||
cp -fr webapp ../../launchers/semantic/target/classes
|
cp -fr webapp ../../launchers/headless/target/classes
|
||||||
cp -fr webapp ../../launchers/chat/target/classes
|
cp -fr webapp ../../launchers/chat/target/classes
|
||||||
cp -fr webapp ../../launchers/standalone/target/classes
|
cp -fr webapp ../../launchers/standalone/target/classes
|
||||||
rm -fr ${buildDir}/webapp
|
rm -fr ${buildDir}/webapp
|
||||||
@@ -55,4 +55,4 @@ fi
|
|||||||
rm -fr $runtimeDir/supersonic*
|
rm -fr $runtimeDir/supersonic*
|
||||||
moveAllToRuntime
|
moveAllToRuntime
|
||||||
setEnvToWeb chat
|
setEnvToWeb chat
|
||||||
setEnvToWeb semantic
|
setEnvToWeb headless
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ runtimeDir=$baseDir/../runtime
|
|||||||
buildDir=$baseDir/build
|
buildDir=$baseDir/build
|
||||||
|
|
||||||
readonly CHAT_APP_NAME="supersonic_chat"
|
readonly CHAT_APP_NAME="supersonic_chat"
|
||||||
readonly SEMANTIC_APP_NAME="supersonic_semantic"
|
readonly HEADLESS_APP_NAME="supersonic_headless"
|
||||||
readonly PYLLM_APP_NAME="supersonic_pyllm"
|
readonly PYLLM_APP_NAME="supersonic_pyllm"
|
||||||
readonly STANDALONE_APP_NAME="supersonic_standalone"
|
readonly STANDALONE_APP_NAME="supersonic_standalone"
|
||||||
readonly CHAT_SERVICE="chat"
|
readonly CHAT_SERVICE="chat"
|
||||||
readonly SEMANTIC_SERVICE="semantic"
|
readonly HEADLESS_SERVICE="headless"
|
||||||
readonly PYLLM_SERVICE="pyllm"
|
readonly PYLLM_SERVICE="pyllm"
|
||||||
readonly STANDALONE_SERVICE="standalone"
|
readonly STANDALONE_SERVICE="standalone"
|
||||||
readonly PYLLM_HOST="127.0.0.1"
|
readonly PYLLM_HOST="127.0.0.1"
|
||||||
@@ -46,7 +46,7 @@ function moveAllToRuntime {
|
|||||||
mv ${buildDir}/supersonic-webapp ${buildDir}/webapp
|
mv ${buildDir}/supersonic-webapp ${buildDir}/webapp
|
||||||
|
|
||||||
moveToRuntime chat
|
moveToRuntime chat
|
||||||
moveToRuntime semantic
|
moveToRuntime headless
|
||||||
moveToRuntime standalone
|
moveToRuntime standalone
|
||||||
rm -fr ${buildDir}/webapp
|
rm -fr ${buildDir}/webapp
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ cd $baseDir
|
|||||||
function setMainClass {
|
function setMainClass {
|
||||||
if [ "$service" == $CHAT_SERVICE ]; then
|
if [ "$service" == $CHAT_SERVICE ]; then
|
||||||
main_class="com.tencent.supersonic.ChatLauncher"
|
main_class="com.tencent.supersonic.ChatLauncher"
|
||||||
elif [ "$service" == $SEMANTIC_SERVICE ]; then
|
elif [ "$service" == $HEADLESS_SERVICE ]; then
|
||||||
main_class="com.tencent.supersonic.SemanticLauncher"
|
main_class="com.tencent.supersonic.HeadlessLauncher"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
setMainClass
|
setMainClass
|
||||||
@@ -42,8 +42,8 @@ setMainClass
|
|||||||
function setAppName {
|
function setAppName {
|
||||||
if [ "$service" == $CHAT_SERVICE ]; then
|
if [ "$service" == $CHAT_SERVICE ]; then
|
||||||
app_name=$CHAT_APP_NAME
|
app_name=$CHAT_APP_NAME
|
||||||
elif [ "$service" == $SEMANTIC_SERVICE ]; then
|
elif [ "$service" == $HEADLESS_SERVICE ]; then
|
||||||
app_name=$SEMANTIC_APP_NAME
|
app_name=$HEADLESS_APP_NAME
|
||||||
elif [ "$service" == $PYLLM_SERVICE ]; then
|
elif [ "$service" == $PYLLM_SERVICE ]; then
|
||||||
app_name=$PYLLM_APP_NAME
|
app_name=$PYLLM_APP_NAME
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -24,4 +24,13 @@ public class AuthenticationConfig {
|
|||||||
@Value("${authentication.token.http.header.key:Authorization}")
|
@Value("${authentication.token.http.header.key:Authorization}")
|
||||||
private String tokenHttpHeaderKey;
|
private String tokenHttpHeaderKey;
|
||||||
|
|
||||||
|
@Value("${authentication.app.appId:appId}")
|
||||||
|
private String appId;
|
||||||
|
|
||||||
|
@Value("${authentication.app.timestamp:timestamp}")
|
||||||
|
private String timestamp;
|
||||||
|
|
||||||
|
@Value("${authentication.app.signature:signature}")
|
||||||
|
private String signature;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package com.tencent.supersonic.auth.api.authentication.pojo;
|
package com.tencent.supersonic.auth.api.authentication.pojo;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
public class Organization {
|
public class Organization {
|
||||||
|
|
||||||
private String id;
|
private String id;
|
||||||
|
|||||||
@@ -24,10 +24,19 @@ public class User {
|
|||||||
return new User(id, name, displayName, email, isAdmin);
|
return new User(id, name, displayName, email, isAdmin);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static User get(Long id, String name) {
|
||||||
|
return new User(id, name, name, name, 0);
|
||||||
|
}
|
||||||
|
|
||||||
public static User getFakeUser() {
|
public static User getFakeUser() {
|
||||||
return new User(1L, "admin", "admin", "admin@email", 1);
|
return new User(1L, "admin", "admin", "admin@email", 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static User getAppUser(int appId) {
|
||||||
|
String name = String.format("app_%s", appId);
|
||||||
|
return new User(1L, name, name, "", 1);
|
||||||
|
}
|
||||||
|
|
||||||
public String getDisplayName() {
|
public String getDisplayName() {
|
||||||
return StringUtils.isBlank(displayName) ? name : displayName;
|
return StringUtils.isBlank(displayName) ? name : displayName;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,15 @@ import com.tencent.supersonic.auth.api.authentication.pojo.Organization;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||||
|
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
public interface UserService {
|
public interface UserService {
|
||||||
|
|
||||||
|
User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse);
|
||||||
|
|
||||||
List<String> getUserNames();
|
List<String> getUserNames();
|
||||||
|
|
||||||
List<User> getUserList();
|
List<User> getUserList();
|
||||||
|
|||||||
@@ -2,6 +2,11 @@ package com.tencent.supersonic.auth.api.authentication.utils;
|
|||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
|
import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
|
||||||
|
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||||
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
@@ -14,7 +19,14 @@ public final class UserHolder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
|
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
|
||||||
return REPO.findUser(request, response);
|
User user = REPO.findUser(request, response);
|
||||||
|
SysParameterService sysParameterService = ContextUtils.getBean(SysParameterService.class);
|
||||||
|
SysParameter sysParameter = sysParameterService.getSysParameter();
|
||||||
|
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
|
||||||
|
&& sysParameter.getAdmins().contains(user.getName())) {
|
||||||
|
user.setIsAdmin(1);
|
||||||
|
}
|
||||||
|
return user;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,17 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Organization> getOrganizationTree() {
|
public List<Organization> getOrganizationTree() {
|
||||||
return Lists.newArrayList();
|
Organization superSonic = new Organization("1", "0",
|
||||||
|
"SuperSonic", "SuperSonic", Lists.newArrayList(), true);
|
||||||
|
Organization hr = new Organization("2", "1",
|
||||||
|
"Hr", "SuperSonic/Hr", Lists.newArrayList(), false);
|
||||||
|
Organization sales = new Organization("3", "1",
|
||||||
|
"Sales", "SuperSonic/Sales", Lists.newArrayList(), false);
|
||||||
|
Organization marketing = new Organization("4", "1",
|
||||||
|
"Marketing", "SuperSonic/Marketing", Lists.newArrayList(), false);
|
||||||
|
List<Organization> subOrganization = Lists.newArrayList(hr, sales, marketing);
|
||||||
|
superSonic.setSubOrganizations(subOrganization);
|
||||||
|
return Lists.newArrayList(superSonic);
|
||||||
}
|
}
|
||||||
|
|
||||||
private User convert(UserDO userDO) {
|
private User convert(UserDO userDO) {
|
||||||
|
|||||||
@@ -5,16 +5,17 @@ import com.tencent.supersonic.auth.api.authentication.constant.UserConstants;
|
|||||||
import com.tencent.supersonic.auth.authentication.service.UserServiceImpl;
|
import com.tencent.supersonic.auth.authentication.service.UserServiceImpl;
|
||||||
import com.tencent.supersonic.auth.authentication.utils.UserTokenUtils;
|
import com.tencent.supersonic.auth.authentication.utils.UserTokenUtils;
|
||||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||||
import java.lang.reflect.Field;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import org.apache.catalina.connector.RequestFacade;
|
import org.apache.catalina.connector.RequestFacade;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
import org.apache.tomcat.util.http.MimeHeaders;
|
import org.apache.tomcat.util.http.MimeHeaders;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
|
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
|
||||||
import org.springframework.web.servlet.HandlerInterceptor;
|
import org.springframework.web.servlet.HandlerInterceptor;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
||||||
@@ -58,6 +59,11 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
|||||||
return "true".equalsIgnoreCase(internal);
|
return "true".equalsIgnoreCase(internal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected boolean isAppRequest(HttpServletRequest request) {
|
||||||
|
String appId = request.getHeader(authenticationConfig.getAppId());
|
||||||
|
return StringUtils.isNotBlank(appId);
|
||||||
|
}
|
||||||
|
|
||||||
protected void reflectSetparam(HttpServletRequest request, String key, String value) {
|
protected void reflectSetparam(HttpServletRequest request, String key, String value) {
|
||||||
try {
|
try {
|
||||||
if (request instanceof StandardMultipartHttpServletRequest) {
|
if (request instanceof StandardMultipartHttpServletRequest) {
|
||||||
|
|||||||
@@ -10,12 +10,12 @@ import com.tencent.supersonic.common.pojo.exception.AccessException;
|
|||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||||
import com.tencent.supersonic.common.util.ThreadContext;
|
import com.tencent.supersonic.common.util.ThreadContext;
|
||||||
import java.lang.reflect.Method;
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.servlet.http.HttpServletResponse;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.web.method.HandlerMethod;
|
import org.springframework.web.method.HandlerMethod;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor {
|
public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor {
|
||||||
@@ -35,7 +35,10 @@ public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor
|
|||||||
setFakerUser(request);
|
setFakerUser(request);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (isAppRequest(request)) {
|
||||||
|
setFakerUser(request);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (handler instanceof HandlerMethod) {
|
if (handler instanceof HandlerMethod) {
|
||||||
HandlerMethod handlerMethod = (HandlerMethod) handler;
|
HandlerMethod handlerMethod = (HandlerMethod) handler;
|
||||||
Method method = handlerMethod.getMethod();
|
Method method = handlerMethod.getMethod();
|
||||||
|
|||||||
@@ -5,18 +5,18 @@ import com.tencent.supersonic.auth.api.authentication.pojo.Organization;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.servlet.http.HttpServletResponse;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
|
||||||
import org.springframework.web.bind.annotation.GetMapping;
|
import org.springframework.web.bind.annotation.GetMapping;
|
||||||
|
import org.springframework.web.bind.annotation.PathVariable;
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
import org.springframework.web.bind.annotation.PathVariable;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
@RestController
|
@RestController
|
||||||
@RequestMapping("/api/auth/user")
|
@RequestMapping("/api/auth/user")
|
||||||
@@ -31,7 +31,7 @@ public class UserController {
|
|||||||
|
|
||||||
@GetMapping("/getCurrentUser")
|
@GetMapping("/getCurrentUser")
|
||||||
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||||
return UserHolder.findUser(httpServletRequest, httpServletResponse);
|
return userService.getCurrentUser(httpServletRequest, httpServletResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping("/getUserNames")
|
@GetMapping("/getUserNames")
|
||||||
|
|||||||
@@ -4,15 +4,39 @@ import com.tencent.supersonic.auth.api.authentication.pojo.Organization;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||||
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
|
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
|
||||||
|
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||||
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class UserServiceImpl implements UserService {
|
public class UserServiceImpl implements UserService {
|
||||||
|
|
||||||
|
private SysParameterService sysParameterService;
|
||||||
|
|
||||||
|
public UserServiceImpl(SysParameterService sysParameterService) {
|
||||||
|
this.sysParameterService = sysParameterService;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||||
|
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||||
|
if (user != null) {
|
||||||
|
SysParameter sysParameter = sysParameterService.getSysParameter();
|
||||||
|
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
|
||||||
|
&& sysParameter.getAdmins().contains(user.getName())) {
|
||||||
|
user.setIsAdmin(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return user;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<String> getUserNames() {
|
public List<String> getUserNames() {
|
||||||
return ComponentFactory.getUserAdaptor().getUserNames();
|
return ComponentFactory.getUserAdaptor().getUserNames();
|
||||||
|
|||||||
@@ -23,7 +23,7 @@
|
|||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.tencent.supersonic</groupId>
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
<artifactId>semantic-api</artifactId>
|
<artifactId>headless-api</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>compile</scope>
|
<scope>compile</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.component;
|
|
||||||
|
|
||||||
import com.github.pagehelper.PageInfo;
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
|
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A semantic layer provides a simplified and consistent view of data from multiple sources.
|
|
||||||
* It abstracts away the complexity of the underlying data sources and provides a unified view
|
|
||||||
* of the data that is easier to understand and use.
|
|
||||||
* <p>
|
|
||||||
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
|
|
||||||
* Implementations of this interface should provide concrete implementations that interact with the
|
|
||||||
* underlying data sources and return results in a consistent format. Or it can be implemented
|
|
||||||
* as proxy to a remote semantic service.
|
|
||||||
* </p>
|
|
||||||
*/
|
|
||||||
public interface SemanticInterpreter {
|
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
|
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user);
|
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
|
||||||
|
|
||||||
List<ModelSchema> getModelSchema();
|
|
||||||
|
|
||||||
List<ModelSchema> getModelSchema(List<Long> ids);
|
|
||||||
|
|
||||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
|
||||||
|
|
||||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
|
||||||
|
|
||||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
|
||||||
|
|
||||||
List<DomainResp> getDomainList(User user);
|
|
||||||
|
|
||||||
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
|
|
||||||
|
|
||||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
|
||||||
|
|
||||||
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class QueryContext {
|
|
||||||
|
|
||||||
private QueryReq request;
|
|
||||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
|
||||||
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
|
||||||
|
|
||||||
public QueryContext(QueryReq request) {
|
|
||||||
this.request = request;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -7,25 +9,25 @@ import java.util.Set;
|
|||||||
|
|
||||||
public class SchemaMapInfo {
|
public class SchemaMapInfo {
|
||||||
|
|
||||||
private Map<Long, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
private Map<Long, List<SchemaElementMatch>> viewElementMatches = new HashMap<>();
|
||||||
|
|
||||||
public Set<Long> getMatchedModels() {
|
public Set<Long> getMatchedViewInfos() {
|
||||||
return modelElementMatches.keySet();
|
return viewElementMatches.keySet();
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElementMatch> getMatchedElements(Long model) {
|
public List<SchemaElementMatch> getMatchedElements(Long view) {
|
||||||
return modelElementMatches.get(model);
|
return viewElementMatches.getOrDefault(view, Lists.newArrayList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Long, List<SchemaElementMatch>> getModelElementMatches() {
|
public Map<Long, List<SchemaElementMatch>> getViewElementMatches() {
|
||||||
return modelElementMatches;
|
return viewElementMatches;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setModelElementMatches(Map<Long, List<SchemaElementMatch>> modelElementMatches) {
|
public void setViewElementMatches(Map<Long, List<SchemaElementMatch>> viewElementMatches) {
|
||||||
this.modelElementMatches = modelElementMatches;
|
this.viewElementMatches = viewElementMatches;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setMatchedElements(Long model, List<SchemaElementMatch> elementMatches) {
|
public void setMatchedElements(Long view, List<SchemaElementMatch> elementMatches) {
|
||||||
modelElementMatches.put(model, elementMatches);
|
viewElementMatches.put(view, elementMatches);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
|
||||||
|
|
||||||
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class SchemaModelClusterMapInfo {
|
|
||||||
|
|
||||||
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
|
||||||
|
|
||||||
public Set<String> getMatchedModelClusters() {
|
|
||||||
return modelElementMatches.keySet();
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
|
||||||
for (String key : modelElementMatches.keySet()) {
|
|
||||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
|
||||||
return modelElementMatches.get(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Lists.newArrayList();
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
|
||||||
return modelElementMatches.get(modelCluster);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
|
||||||
return modelElementMatches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
|
||||||
if (CollectionUtils.isEmpty(modelIds)) {
|
|
||||||
return modelElementMatches;
|
|
||||||
}
|
|
||||||
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
|
||||||
for (String key : modelElementMatches.keySet()) {
|
|
||||||
for (Long modelId : modelIds) {
|
|
||||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
|
||||||
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return modelElementMatchesFiltered;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
|
||||||
this.modelElementMatches = modelElementMatches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
|
||||||
modelElementMatches.put(modelCluster, elementMatches);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -26,7 +26,7 @@ public class SemanticParseInfo {
|
|||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode;
|
private String queryMode;
|
||||||
private ModelCluster model = new ModelCluster();
|
private SchemaElement view;
|
||||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||||
private SchemaElement entity;
|
private SchemaElement entity;
|
||||||
@@ -44,20 +44,6 @@ public class SemanticParseInfo {
|
|||||||
private SqlInfo sqlInfo = new SqlInfo();
|
private SqlInfo sqlInfo = new SqlInfo();
|
||||||
private QueryType queryType = QueryType.ID;
|
private QueryType queryType = QueryType.ID;
|
||||||
|
|
||||||
public String getModelClusterKey() {
|
|
||||||
if (model == null) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return model.getKey();
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getModelName() {
|
|
||||||
if (model == null) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return model.getName();
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -86,27 +72,15 @@ public class SemanticParseInfo {
|
|||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<Long, Integer> getModelElementCountMap() {
|
public Long getViewId() {
|
||||||
Map<Long, Integer> elementCountMap = new HashMap<>();
|
if (view == null) {
|
||||||
elementMatches.stream().filter(element -> element.getElement().getModel() != null)
|
return null;
|
||||||
.forEach(element -> {
|
}
|
||||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
return view.getView();
|
||||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
|
||||||
});
|
|
||||||
return elementCountMap;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getModelId() {
|
public SchemaElement getModel() {
|
||||||
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
return view;
|
||||||
Long modelId = -1L;
|
|
||||||
int maxCnt = 0;
|
|
||||||
for (Long model : elementCountMap.keySet()) {
|
|
||||||
if (elementCountMap.get(model) > maxCnt) {
|
|
||||||
maxCnt = elementCountMap.get(model);
|
|
||||||
modelId = model;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return modelId;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,26 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
import org.springframework.util.CollectionUtils;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
public class SemanticSchema implements Serializable {
|
public class SemanticSchema implements Serializable {
|
||||||
|
|
||||||
private List<ModelSchema> modelSchemaList;
|
private List<ViewSchema> viewSchemaList;
|
||||||
|
|
||||||
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
public SemanticSchema(List<ViewSchema> viewSchemaList) {
|
||||||
this.modelSchemaList = modelSchemaList;
|
this.viewSchemaList = viewSchemaList;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void add(ModelSchema schema) {
|
public void add(ViewSchema schema) {
|
||||||
modelSchemaList.add(schema);
|
viewSchemaList.add(schema);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
@@ -30,8 +30,8 @@ public class SemanticSchema implements Serializable {
|
|||||||
case ENTITY:
|
case ENTITY:
|
||||||
element = getElementsById(elementID, getEntities());
|
element = getElementsById(elementID, getEntities());
|
||||||
break;
|
break;
|
||||||
case MODEL:
|
case VIEW:
|
||||||
element = getElementsById(elementID, getModels());
|
element = getElementsById(elementID, getViews());
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
element = getElementsById(elementID, getMetrics());
|
element = getElementsById(elementID, getMetrics());
|
||||||
@@ -52,58 +52,29 @@ public class SemanticSchema implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
|
public Map<Long, String> getViewIdToName() {
|
||||||
Optional<SchemaElement> element = Optional.empty();
|
return viewSchemaList.stream()
|
||||||
|
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
|
||||||
switch (elementType) {
|
|
||||||
case ENTITY:
|
|
||||||
element = getElementsByNameOrAlias(name, getEntities());
|
|
||||||
break;
|
|
||||||
case MODEL:
|
|
||||||
element = getElementsByNameOrAlias(name, getModels());
|
|
||||||
break;
|
|
||||||
case METRIC:
|
|
||||||
element = getElementsByNameOrAlias(name, getMetrics());
|
|
||||||
break;
|
|
||||||
case DIMENSION:
|
|
||||||
element = getElementsByNameOrAlias(name, getDimensions());
|
|
||||||
break;
|
|
||||||
case VALUE:
|
|
||||||
element = getElementsByNameOrAlias(name, getDimensionValues());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if (element.isPresent()) {
|
|
||||||
return element.get();
|
|
||||||
} else {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<Long, String> getModelIdToName() {
|
|
||||||
return modelSchemaList.stream()
|
|
||||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getDimensionValues() {
|
public List<SchemaElement> getDimensionValues() {
|
||||||
List<SchemaElement> dimensionValues = new ArrayList<>();
|
List<SchemaElement> dimensionValues = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
viewSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||||
return dimensionValues;
|
return dimensionValues;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getDimensions() {
|
public List<SchemaElement> getDimensions() {
|
||||||
List<SchemaElement> dimensions = new ArrayList<>();
|
List<SchemaElement> dimensions = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
viewSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||||
return dimensions;
|
return dimensions;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
public List<SchemaElement> getDimensions(Long viewId) {
|
||||||
List<SchemaElement> dimensions = getDimensions();
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
return getElementsByModelId(modelIds, dimensions);
|
return getElementsByViewId(viewId, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getDimensions(Long id) {
|
public SchemaElement getDimension(Long id) {
|
||||||
List<SchemaElement> dimensions = getDimensions();
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||||
return dimension.orElse(null);
|
return dimension.orElse(null);
|
||||||
@@ -111,37 +82,43 @@ public class SemanticSchema implements Serializable {
|
|||||||
|
|
||||||
public List<SchemaElement> getTags() {
|
public List<SchemaElement> getTags() {
|
||||||
List<SchemaElement> tags = new ArrayList<>();
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
viewSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||||
return tags;
|
return tags;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
public List<SchemaElement> getTags(Long viewId) {
|
||||||
List<SchemaElement> tags = new ArrayList<>();
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
viewSchemaList.stream().filter(schemaElement ->
|
||||||
|
viewId.equals(schemaElement.getView().getView()))
|
||||||
.forEach(d -> tags.addAll(d.getTags()));
|
.forEach(d -> tags.addAll(d.getTags()));
|
||||||
return tags;
|
return tags;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics() {
|
public List<SchemaElement> getMetrics() {
|
||||||
List<SchemaElement> metrics = new ArrayList<>();
|
List<SchemaElement> metrics = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
viewSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
public List<SchemaElement> getMetrics(Long viewId) {
|
||||||
List<SchemaElement> metrics = getMetrics();
|
List<SchemaElement> metrics = getMetrics();
|
||||||
return getElementsByModelId(modelIds, metrics);
|
return getElementsByViewId(viewId, metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getEntities() {
|
public List<SchemaElement> getEntities() {
|
||||||
List<SchemaElement> entities = new ArrayList<>();
|
List<SchemaElement> entities = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
viewSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||||
return entities;
|
return entities;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
public List<SchemaElement> getEntities(Long viewId) {
|
||||||
|
List<SchemaElement> entities = getEntities();
|
||||||
|
return getElementsByViewId(viewId, entities);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SchemaElement> getElementsByViewId(Long viewId, List<SchemaElement> elements) {
|
||||||
return elements.stream()
|
return elements.stream()
|
||||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
.filter(schemaElement -> viewId.equals(schemaElement.getView()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,32 +128,30 @@ public class SemanticSchema implements Serializable {
|
|||||||
.findFirst();
|
.findFirst();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Optional<SchemaElement> getElementsByNameOrAlias(String name, List<SchemaElement> elements) {
|
public SchemaElement getView(Long viewId) {
|
||||||
return elements.stream()
|
List<SchemaElement> views = getViews();
|
||||||
.filter(schemaElement ->
|
return getElementsById(viewId, views).orElse(null);
|
||||||
name.equals(schemaElement.getName()) || schemaElement.getAlias().contains(name)
|
|
||||||
).findFirst();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getModels() {
|
public List<SchemaElement> getViews() {
|
||||||
List<SchemaElement> models = new ArrayList<>();
|
List<SchemaElement> views = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
viewSchemaList.stream().forEach(d -> views.add(d.getView()));
|
||||||
return models;
|
return views;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
public Map<String, String> getBizNameToName(Long viewId) {
|
||||||
List<SchemaElement> allElements = new ArrayList<>();
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
allElements.addAll(getDimensions(modelIds));
|
allElements.addAll(getDimensions(viewId));
|
||||||
allElements.addAll(getMetrics(modelIds));
|
allElements.addAll(getMetrics(viewId));
|
||||||
return allElements.stream()
|
return allElements.stream()
|
||||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Long, ModelSchema> getModelSchemaMap() {
|
public Map<Long, ViewSchema> getViewSchemaMap() {
|
||||||
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
if (CollectionUtils.isEmpty(viewSchemaList)) {
|
||||||
return new HashMap<>();
|
return new HashMap<>();
|
||||||
}
|
}
|
||||||
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
return viewSchemaList.stream().collect(Collectors.toMap(viewSchema
|
||||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
-> viewSchema.getView().getView(), viewSchema -> viewSchema));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,26 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
import com.google.common.collect.Sets;
|
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ModelSchema {
|
public class ViewSchema {
|
||||||
|
|
||||||
private SchemaElement model;
|
private SchemaElement view;
|
||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||||
private Set<SchemaElement> tags = new HashSet<>();
|
private Set<SchemaElement> tags = new HashSet<>();
|
||||||
private SchemaElement entity = new SchemaElement();
|
private SchemaElement entity = new SchemaElement();
|
||||||
private List<ModelRela> modelRelas = new ArrayList<>();
|
private QueryConfig queryConfig;
|
||||||
|
|
||||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
Optional<SchemaElement> element = Optional.empty();
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
@@ -29,8 +29,8 @@ public class ModelSchema {
|
|||||||
case ENTITY:
|
case ENTITY:
|
||||||
element = Optional.ofNullable(entity);
|
element = Optional.ofNullable(entity);
|
||||||
break;
|
break;
|
||||||
case MODEL:
|
case VIEW:
|
||||||
element = Optional.of(model);
|
element = Optional.of(view);
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
@@ -61,8 +61,8 @@ public class ModelSchema {
|
|||||||
case ENTITY:
|
case ENTITY:
|
||||||
element = Optional.ofNullable(entity);
|
element = Optional.ofNullable(entity);
|
||||||
break;
|
break;
|
||||||
case MODEL:
|
case VIEW:
|
||||||
element = Optional.of(model);
|
element = Optional.of(view);
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
@@ -83,16 +83,31 @@ public class ModelSchema {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Set<Long> getModelClusterSet() {
|
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
|
||||||
if (CollectionUtils.isEmpty(modelRelas)) {
|
if (queryConfig == null) {
|
||||||
return Sets.newHashSet();
|
return null;
|
||||||
}
|
}
|
||||||
Set<Long> modelClusterSet = new HashSet<>();
|
if (queryConfig.getTagTypeDefaultConfig() == null) {
|
||||||
modelRelas.forEach(modelRela -> {
|
return null;
|
||||||
modelClusterSet.add(modelRela.getToModelId());
|
}
|
||||||
modelClusterSet.add(modelRela.getFromModelId());
|
return queryConfig.getTagTypeDefaultConfig().getTimeDefaultConfig();
|
||||||
});
|
}
|
||||||
return modelClusterSet;
|
|
||||||
|
public TimeDefaultConfig getMetricTypeTimeDefaultConfig() {
|
||||||
|
if (queryConfig == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (queryConfig.getMetricTypeDefaultConfig() == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return queryConfig.getMetricTypeDefaultConfig().getTimeDefaultConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
public TagTypeDefaultConfig getTagTypeDefaultConfig() {
|
||||||
|
if (queryConfig == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return queryConfig.getTagTypeDefaultConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -13,26 +12,5 @@ public class ChatDefaultConfigReq {
|
|||||||
private List<Long> dimensionIds = new ArrayList<>();
|
private List<Long> dimensionIds = new ArrayList<>();
|
||||||
private List<Long> metricIds = new ArrayList<>();
|
private List<Long> metricIds = new ArrayList<>();
|
||||||
|
|
||||||
/**
|
|
||||||
* default time span unit
|
|
||||||
*/
|
|
||||||
private Integer unit = 1;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* default time type: day
|
|
||||||
* DAY, WEEK, MONTH, YEAR
|
|
||||||
*/
|
|
||||||
private String period = Constants.DAY;
|
|
||||||
|
|
||||||
private TimeMode timeMode = TimeMode.LAST;
|
|
||||||
|
|
||||||
public enum TimeMode {
|
|
||||||
/**
|
|
||||||
* date mode
|
|
||||||
* LAST - a certain time
|
|
||||||
* RECENT - a period time
|
|
||||||
*/
|
|
||||||
LAST, RECENT
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.ToString;
|
|
||||||
|
|
||||||
import javax.validation.constraints.NotNull;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static java.time.LocalDate.now;
|
|
||||||
|
|
||||||
@ToString
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class DictLatestTaskReq {
|
|
||||||
|
|
||||||
@NotNull
|
|
||||||
private Long modelId;
|
|
||||||
|
|
||||||
private List<Long> dimIds;
|
|
||||||
|
|
||||||
private String createdAt = now().plusDays(-4).toString();
|
|
||||||
}
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.ToString;
|
|
||||||
|
|
||||||
@ToString
|
|
||||||
@Data
|
|
||||||
public class DictTaskFilterReq {
|
|
||||||
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
private String name;
|
|
||||||
|
|
||||||
private String createdBy;
|
|
||||||
|
|
||||||
private String createdAt;
|
|
||||||
|
|
||||||
private TaskStatusEnum status;
|
|
||||||
}
|
|
||||||
@@ -13,7 +13,7 @@ public class PluginQueryReq {
|
|||||||
|
|
||||||
private String type;
|
private String type;
|
||||||
|
|
||||||
private String model;
|
private String view;
|
||||||
|
|
||||||
private String pattern;
|
private String pattern;
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ public class SimilarQueryReq {
|
|||||||
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
|
||||||
private String modelId;
|
private Long viewId;
|
||||||
|
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -21,11 +21,11 @@ public class ChatDefaultRichConfigResp {
|
|||||||
private Integer unit = 1;
|
private Integer unit = 1;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* default time type: day
|
* default time type:
|
||||||
* DAY, WEEK, MONTH, YEAR
|
* DAY, WEEK, MONTH, YEAR
|
||||||
*/
|
*/
|
||||||
private String period = Constants.DAY;
|
private String period = Constants.DAY;
|
||||||
|
|
||||||
private ChatDefaultConfigReq.TimeMode timeMode;
|
private TimeMode timeMode;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,12 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
public class DataInfo {
|
public class DataInfo {
|
||||||
|
|
||||||
private Integer itemId;
|
private Integer itemId;
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class EntityInfo {
|
public class EntityInfo {
|
||||||
|
|
||||||
private ModelInfo modelInfo = new ModelInfo();
|
private ViewInfo viewInfo = new ViewInfo();
|
||||||
private List<DataInfo> dimensions = new ArrayList<>();
|
private List<DataInfo> dimensions = new ArrayList<>();
|
||||||
private List<DataInfo> metrics = new ArrayList<>();
|
private List<DataInfo> metrics = new ArrayList<>();
|
||||||
private String entityId;
|
private String entityId;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -17,4 +17,6 @@ public class QueryResp {
|
|||||||
private QueryResult queryResult;
|
private QueryResult queryResult;
|
||||||
private List<SemanticParseInfo> parseInfos;
|
private List<SemanticParseInfo> parseInfos;
|
||||||
private List<SimilarQueryRecallResp> similarQueries;
|
private List<SimilarQueryRecallResp> similarQueries;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import java.io.Serializable;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ModelInfo extends DataInfo implements Serializable {
|
public class ViewInfo extends DataInfo implements Serializable {
|
||||||
|
|
||||||
private List<String> words;
|
private List<String> words;
|
||||||
private String primaryKey;
|
private String primaryKey;
|
||||||
@@ -21,7 +21,6 @@
|
|||||||
<groupId>org.springframework</groupId>
|
<groupId>org.springframework</groupId>
|
||||||
<artifactId>spring-context</artifactId>
|
<artifactId>spring-context</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.testng</groupId>
|
<groupId>org.testng</groupId>
|
||||||
<artifactId>testng</artifactId>
|
<artifactId>testng</artifactId>
|
||||||
@@ -72,17 +71,12 @@
|
|||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.tencent.supersonic</groupId>
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
<artifactId>chat-knowledge</artifactId>
|
<artifactId>headless-api</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.tencent.supersonic</groupId>
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
<artifactId>semantic-api</artifactId>
|
<artifactId>headless-core</artifactId>
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.tencent.supersonic</groupId>
|
|
||||||
<artifactId>semantic-query</artifactId>
|
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
@@ -104,6 +98,12 @@
|
|||||||
<version>${mockito-inline.version}</version>
|
<version>${mockito-inline.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
|
<artifactId>headless-server</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
|
||||||
|
|
||||||
public enum AgentToolType {
|
|
||||||
NL2SQL_RULE,
|
|
||||||
NL2SQL_LLM,
|
|
||||||
PLUGIN,
|
|
||||||
ANALYTICS
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DataAnalyticsTool extends AgentTool {
|
|
||||||
|
|
||||||
private Long modelId;
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class EntityInternalDetail {
|
|
||||||
|
|
||||||
List<DimSchemaResp> dimensionList;
|
|
||||||
List<MetricSchemaResp> metricList;
|
|
||||||
}
|
|
||||||
@@ -1,14 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -19,10 +23,13 @@ public class Agent extends RecordInfo {
|
|||||||
private String name;
|
private String name;
|
||||||
private String description;
|
private String description;
|
||||||
|
|
||||||
//0 offline, 1 online
|
/**
|
||||||
|
* 0 offline, 1 online
|
||||||
|
*/
|
||||||
private Integer status;
|
private Integer status;
|
||||||
private List<String> examples;
|
private List<String> examples;
|
||||||
private String agentConfig;
|
private String agentConfig;
|
||||||
|
|
||||||
public List<String> getTools(AgentToolType type) {
|
public List<String> getTools(AgentToolType type) {
|
||||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||||
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
|
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
|
||||||
@@ -45,4 +52,27 @@ public class Agent extends RecordInfo {
|
|||||||
return enableSearch != null && enableSearch == 1;
|
return enableSearch != null && enableSearch == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||||
|
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
|
||||||
|
List<String> tools = this.getTools(agentToolType);
|
||||||
|
if (CollectionUtils.isEmpty(tools)) {
|
||||||
|
return Lists.newArrayList();
|
||||||
|
}
|
||||||
|
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||||
|
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||||
|
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||||
|
return new HashSet<>();
|
||||||
|
}
|
||||||
|
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
|
||||||
|
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||||
|
.flatMap(Collection::stream)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public enum AgentToolType {
|
||||||
|
NL2SQL_RULE("基于规则Text-to-SQL"),
|
||||||
|
NL2SQL_LLM("基于大模型Text-to-SQL"),
|
||||||
|
PLUGIN("第三方插件");
|
||||||
|
|
||||||
|
private String title;
|
||||||
|
|
||||||
|
AgentToolType(String title) {
|
||||||
|
this.title = title;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<AgentToolType, String> getToolTypes() {
|
||||||
|
Map<AgentToolType, String> map = new HashMap<>();
|
||||||
|
map.put(NL2SQL_RULE, NL2SQL_RULE.title);
|
||||||
|
map.put(NL2SQL_LLM, NL2SQL_LLM.title);
|
||||||
|
map.put(PLUGIN, PLUGIN.title);
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -1,16 +1,17 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class NL2SQLTool extends AgentTool {
|
public class NL2SQLTool extends AgentTool {
|
||||||
|
|
||||||
protected List<Long> modelIds;
|
protected List<Long> viewIds;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
|
|||||||
private List<String> queryTypes;
|
private List<String> queryTypes;
|
||||||
|
|
||||||
public boolean isContainsAllModel() {
|
public boolean isContainsAllModel() {
|
||||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
package com.tencent.supersonic.knowledge.dictionary;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
|
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
|
||||||
import java.io.FileNotFoundException;
|
import java.io.FileNotFoundException;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -11,7 +13,7 @@ import org.springframework.context.annotation.Configuration;
|
|||||||
@Data
|
@Data
|
||||||
@Configuration
|
@Configuration
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LocalFileConfig {
|
public class ChatLocalFileConfig {
|
||||||
|
|
||||||
|
|
||||||
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.knowledge.semantic;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class EntityInternalDetail {
|
||||||
|
|
||||||
|
List<DimSchemaResp> dimensionList;
|
||||||
|
List<MetricSchemaResp> metricList;
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -16,10 +16,12 @@ public class LLMParserConfig {
|
|||||||
@Value("${query2sql.path:/query2sql}")
|
@Value("${query2sql.path:/query2sql}")
|
||||||
private String queryToSqlPath;
|
private String queryToSqlPath;
|
||||||
|
|
||||||
@Value("${dimension.topn:5}")
|
@Value("${dimension.topn:10}")
|
||||||
private Integer dimensionTopN;
|
private Integer dimensionTopN;
|
||||||
|
|
||||||
@Value("${metric.topn:5}")
|
@Value("${metric.topn:10}")
|
||||||
private Integer metricTopN;
|
private Integer metricTopN;
|
||||||
|
|
||||||
|
@Value("${all.model:false}")
|
||||||
|
private Boolean allModel;
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
import com.tencent.supersonic.common.service.SysParameterService;
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -52,7 +52,7 @@ public class OptimizationConfig {
|
|||||||
@Value("${embedding.mapper.round.number:10}")
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
private int embeddingMapperRoundNumber;
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
@Value("${embedding.mapper.distance.threshold:0.58}")
|
@Value("${embedding.mapper.distance.threshold:0.01}")
|
||||||
private Double embeddingMapperDistanceThreshold;
|
private Double embeddingMapperDistanceThreshold;
|
||||||
|
|
||||||
@Value("${s2SQL.linking.value.switch:true}")
|
@Value("${s2SQL.linking.value.switch:true}")
|
||||||
@@ -64,17 +64,17 @@ public class OptimizationConfig {
|
|||||||
@Value("${s2SQL.use.switch:true}")
|
@Value("${s2SQL.use.switch:true}")
|
||||||
private boolean useS2SqlSwitch;
|
private boolean useS2SqlSwitch;
|
||||||
|
|
||||||
@Value("${text2sql.example.num:10}")
|
@Value("${text2sql.example.num:15}")
|
||||||
private int text2sqlExampleNum;
|
private int text2sqlExampleNum;
|
||||||
|
|
||||||
@Value("${text2sql.fewShots.num:5}")
|
@Value("${text2sql.fewShots.num:10}")
|
||||||
private int text2sqlFewShotsNum;
|
private int text2sqlFewShotsNum;
|
||||||
|
|
||||||
@Value("${text2sql.self.consistency.num:2}")
|
@Value("${text2sql.self.consistency.num:5}")
|
||||||
private int text2sqlSelfConsistencyNum;
|
private int text2sqlSelfConsistencyNum;
|
||||||
|
|
||||||
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
@Value("${parse.show.count:3}")
|
||||||
private String text2sqlCollectionName;
|
private Integer parseShowCount;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private SysParameterService sysParameterService;
|
private SysParameterService sysParameterService;
|
||||||
@@ -147,6 +147,10 @@ public class OptimizationConfig {
|
|||||||
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
|
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Integer getParseShowCount() {
|
||||||
|
return convertValue("parse.show.count", Integer.class, parseShowCount);
|
||||||
|
}
|
||||||
|
|
||||||
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||||
try {
|
try {
|
||||||
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
||||||
@@ -1,20 +1,19 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.core.env.Environment;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -32,23 +31,23 @@ import java.util.stream.Collectors;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||||
|
|
||||||
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
try {
|
try {
|
||||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
doCorrect(queryReq, semanticParseInfo);
|
doCorrect(queryContext, semanticParseInfo);
|
||||||
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||||
|
|
||||||
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
|
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
|
||||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||||
@@ -56,7 +55,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
// support fieldName and field alias
|
// support fieldName and field alias
|
||||||
Map<String, String> result = dbAllFields.stream()
|
Map<String, String> result = dbAllFields.stream()
|
||||||
.filter(entry -> modelIds.contains(entry.getModel()))
|
.filter(entry -> viewId.equals(entry.getView()))
|
||||||
.flatMap(schemaElement -> {
|
.flatMap(schemaElement -> {
|
||||||
Set<String> elements = new HashSet<>();
|
Set<String> elements = new HashSet<>();
|
||||||
elements.add(schemaElement.getName());
|
elements.add(schemaElement.getName());
|
||||||
@@ -78,14 +77,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
|
||||||
|
//decide whether add order by expression field to select
|
||||||
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
|
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||||
|
}
|
||||||
|
|
||||||
// If there is no aggregate function in the S2SQL statement and
|
// If there is no aggregate function in the S2SQL statement and
|
||||||
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||||
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||||
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
@@ -97,16 +102,15 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
needAddFields.removeAll(selectFields);
|
needAddFields.removeAll(selectFields);
|
||||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
|
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
Long viewId = semanticParseInfo.getView().getView();
|
||||||
|
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
|
||||||
List<SchemaElement> metrics = getMetricElements(modelIds);
|
|
||||||
|
|
||||||
Map<String, String> metricToAggregate = metrics.stream()
|
Map<String, String> metricToAggregate = metrics.stream()
|
||||||
.map(schemaElement -> {
|
.map(schemaElement -> {
|
||||||
@@ -127,13 +131,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
|
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
return semanticSchema.getMetrics(modelIds);
|
return semanticSchema.getMetrics(viewId);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||||
|
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Group by" section in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
Boolean needAddGroupBy = needAddGroupBy(queryContext, semanticParseInfo);
|
||||||
|
if (!needAddGroupBy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
addGroupByFields(queryContext, semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
|
ViewService viewService = ContextUtils.getBean(ViewService.class);
|
||||||
|
ModelService modelService = ContextUtils.getBean(ModelService.class);
|
||||||
|
ViewResp viewResp = viewService.getView(viewId);
|
||||||
|
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
MetaFilter metaFilter = new MetaFilter();
|
||||||
|
metaFilter.setIds(modelIds);
|
||||||
|
List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
|
||||||
|
for (ModelResp modelResp : modelRespList) {
|
||||||
|
List<Dim> dimList = modelResp.getModelDetail().getDimensions();
|
||||||
|
for (Dim dim : dimList) {
|
||||||
|
if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//add dimension group by
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
// check has distinct
|
||||||
|
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
|
||||||
|
log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
//add alias field name
|
||||||
|
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||||
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// if only date in select not add group by.
|
||||||
|
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||||
|
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
|
||||||
|
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||||
|
.flatMap(
|
||||||
|
schemaElement -> {
|
||||||
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream();
|
||||||
|
}
|
||||||
|
).collect(Collectors.toSet());
|
||||||
|
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
return dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
|
//add dimension group by
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
//add alias field name
|
||||||
|
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||||
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
|
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
|
Set<String> groupByFields = selectFields.stream()
|
||||||
|
.filter(field -> dimensions.contains(field))
|
||||||
|
.filter(field -> {
|
||||||
|
if (!CollectionUtils.isEmpty(aggregateFields) && aggregateFields.contains(field)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
})
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||||
|
addAggregate(queryContext, semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||||
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
|
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
addAggregateToMetric(queryContext, semanticParseInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.core.env.Environment;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class HavingCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
//add aggregate to all metric
|
||||||
|
addHaving(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
|
//decide whether add having expression field to select
|
||||||
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
|
addHavingToSelect(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
Long viewId = semanticParseInfo.getView().getView();
|
||||||
|
|
||||||
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
|
||||||
|
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||||
|
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(metrics)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||||
|
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||||
|
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,30 +1,31 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.ParseResult;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform schema corrections on the Schema information in S2QL.
|
* Perform schema corrections on the Schema information in S2SQL.
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
correctAggFunction(semanticParseInfo);
|
correctAggFunction(semanticParseInfo);
|
||||||
|
|
||||||
@@ -34,26 +35,26 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
updateFieldValueByLinkingValue(semanticParseInfo);
|
updateFieldValueByLinkingValue(semanticParseInfo);
|
||||||
|
|
||||||
correctFieldName(semanticParseInfo);
|
correctFieldName(queryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
|
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
|
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,7 +70,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
|
||||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,7 +102,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
)));
|
)));
|
||||||
|
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -14,10 +14,10 @@ import org.springframework.util.CollectionUtils;
|
|||||||
public class SelectCorrector extends BaseSemanticCorrector {
|
public class SelectCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||||
&& !CollectionUtils.isEmpty(selectFields)
|
&& !CollectionUtils.isEmpty(selectFields)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.component;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A semantic corrector checks validity of extracted semantic information and
|
* A semantic corrector checks validity of extracted semantic information and
|
||||||
@@ -9,5 +9,5 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|||||||
*/
|
*/
|
||||||
public interface SemanticCorrector {
|
public interface SemanticCorrector {
|
||||||
|
|
||||||
void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
|
import java.util.Objects;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the time in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class TimeCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
parserDateDiffFunction(semanticParseInfo);
|
||||||
|
|
||||||
|
addLowerBoundDate(semanticParseInfo);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||||
|
if (Objects.isNull(dateBoundInfo)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
|
||||||
|
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
|
||||||
|
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
|
||||||
|
String upperDate = dateBoundInfo.getUpperDate();
|
||||||
|
try {
|
||||||
|
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
|
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
|
||||||
|
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
|
||||||
|
} catch (JSQLParserException e) {
|
||||||
|
log.error("parseCondExpression", e);
|
||||||
|
}
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,25 +1,24 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.S2SqlDateHelper;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -27,7 +26,6 @@ import java.util.HashMap;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -37,19 +35,17 @@ import java.util.stream.Collectors;
|
|||||||
public class WhereCorrector extends BaseSemanticCorrector {
|
public class WhereCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
addDateIfNotExist(semanticParseInfo);
|
addDateIfNotExist(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
parserDateDiffFunction(semanticParseInfo);
|
addQueryFilter(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
addQueryFilter(queryReq, semanticParseInfo);
|
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
updateFieldValueByTechName(semanticParseInfo);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
|
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
|
||||||
|
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
|
||||||
@@ -61,26 +57,29 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
} catch (JSQLParserException e) {
|
} catch (JSQLParserException e) {
|
||||||
log.error("parseCondExpression", e);
|
log.error("parseCondExpression", e);
|
||||||
}
|
}
|
||||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
|
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||||
String currentDate = S2SqlDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
|
||||||
if (StringUtils.isNotBlank(currentDate)) {
|
semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
|
||||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
||||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
&& StringUtils.isNotBlank(startEndDate.getRight())) {
|
||||||
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
|
String dateChName = TimeDimensionEnum.DAY.getChName();
|
||||||
|
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
|
||||||
|
startEndDate.getLeft(), dateChName, startEndDate.getRight());
|
||||||
|
try {
|
||||||
|
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||||
|
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||||
|
} catch (JSQLParserException e) {
|
||||||
|
log.error("parseCondExpression:{}", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
@@ -100,17 +99,17 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
|
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(dimensions)) {
|
if (CollectionUtils.isEmpty(dimensions)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||||
aliasAndBizNameToTechName);
|
aliasAndBizNameToTechName);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
@@ -1,23 +1,22 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.BeanUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseMapper implements SchemaMapper {
|
public abstract class BaseMapper implements SchemaMapper {
|
||||||
@@ -27,7 +26,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
|
|
||||||
String simpleName = this.getClass().getSimpleName();
|
String simpleName = this.getClass().getSimpleName();
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
|
||||||
|
|
||||||
try {
|
try {
|
||||||
doMap(queryContext);
|
doMap(queryContext);
|
||||||
@@ -36,13 +35,13 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
long cost = System.currentTimeMillis() - startTime;
|
long cost = System.currentTimeMillis() - startTime;
|
||||||
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void doMap(QueryContext queryContext);
|
public abstract void doMap(QueryContext queryContext);
|
||||||
|
|
||||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
|
||||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = modelElementMatches.get(modelId);
|
schemaElementMatches = modelElementMatches.get(modelId);
|
||||||
@@ -68,14 +67,14 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
|
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
|
||||||
|
SemanticSchema semanticSchema) {
|
||||||
SchemaElement element = new SchemaElement();
|
SchemaElement element = new SchemaElement();
|
||||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||||
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
if (Objects.isNull(viewSchema)) {
|
||||||
if (Objects.isNull(modelSchema)) {
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
|
||||||
if (Objects.isNull(elementDb)) {
|
if (Objects.isNull(elementDb)) {
|
||||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||||
return null;
|
return null;
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -27,24 +27,25 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
private MapperHelper mapperHelper;
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
String text = queryContext.getRequest().getQueryText();
|
Set<Long> detectViewIds) {
|
||||||
|
String text = queryContext.getQueryText();
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
|
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
|
||||||
|
|
||||||
List<T> detects = detect(queryContext, terms, detectModelIds);
|
List<T> detects = detect(queryContext, terms, detectViewIds);
|
||||||
Map<MatchText, List<T>> result = new HashMap<>();
|
Map<MatchText, List<T>> result = new HashMap<>();
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||||
String text = queryContext.getRequest().getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
Set<T> results = new HashSet<>();
|
Set<T> results = new HashSet<>();
|
||||||
|
|
||||||
Set<String> detectSegments = new HashSet<>();
|
Set<String> detectSegments = new HashSet<>();
|
||||||
@@ -55,25 +56,26 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
if (index <= text.length()) {
|
if (index <= text.length()) {
|
||||||
String detectSegment = text.substring(startIndex, index);
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
detectSegments.add(detectSegment);
|
detectSegments.add(detectSegment);
|
||||||
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
}
|
}
|
||||||
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
detectByBatch(queryContext, results, detectViewIds, detectSegments);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||||
return terms.stream().sorted(Comparator.comparing(Term::length))
|
return terms.stream().sorted(Comparator.comparing(S2Term::length))
|
||||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),
|
||||||
|
(value1, value2) -> value2));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||||
@@ -101,10 +103,10 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
||||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||||
terms = filterByModelIds(terms, detectModelIds);
|
terms = filterByViewId(terms, viewIds);
|
||||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
|
||||||
List<T> matches = new ArrayList<>();
|
List<T> matches = new ArrayList<>();
|
||||||
if (Objects.isNull(matchResult)) {
|
if (Objects.isNull(matchResult)) {
|
||||||
return matches;
|
return matches;
|
||||||
@@ -119,27 +121,27 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
return matches;
|
return matches;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
public List<S2Term> filterByViewId(List<S2Term> terms, Set<Long> viewIds) {
|
||||||
logTerms(terms);
|
logTerms(terms);
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
if (CollectionUtils.isNotEmpty(viewIds)) {
|
||||||
terms = terms.stream().filter(term -> {
|
terms = terms.stream().filter(term -> {
|
||||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
Long viewId = NatureHelper.getViewId(term.getNature().toString());
|
||||||
if (Objects.nonNull(modelId)) {
|
if (Objects.nonNull(viewId)) {
|
||||||
return detectModelIds.contains(modelId);
|
return viewIds.contains(viewId);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
log.info("terms filter by modelIds:{}", detectModelIds);
|
log.info("terms filter by viewId:{}", viewIds);
|
||||||
logTerms(terms);
|
logTerms(terms);
|
||||||
}
|
}
|
||||||
return terms;
|
return terms;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void logTerms(List<Term> terms) {
|
public void logTerms(List<S2Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(terms)) {
|
if (CollectionUtils.isEmpty(terms)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (Term term : terms) {
|
for (S2Term term : terms) {
|
||||||
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -148,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
|
|
||||||
public abstract String getMapKey(T a);
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
|
||||||
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
String detectSegment, int offset);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -15,11 +20,6 @@ import java.util.Map;
|
|||||||
import java.util.Map.Entry;
|
import java.util.Map.Entry;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||||
@@ -33,15 +33,13 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
@Autowired
|
@Autowired
|
||||||
private MapperHelper mapperHelper;
|
private MapperHelper mapperHelper;
|
||||||
@Autowired
|
|
||||||
private SchemaService schemaService;
|
|
||||||
private List<SchemaElement> allElements;
|
private List<SchemaElement> allElements;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<Term> terms,
|
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
Set<Long> detectModelIds) {
|
Set<Long> detectViewIds) {
|
||||||
this.allElements = getSchemaElements();
|
this.allElements = getSchemaElements(queryContext);
|
||||||
return super.match(queryContext, terms, detectModelIds);
|
return super.match(queryContext, terms, detectViewIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -56,16 +54,13 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex, Integer index, int offset) {
|
String detectSegment, int offset) {
|
||||||
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
|
|
||||||
if (StringUtils.isBlank(detectSegment)) {
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
|
||||||
|
|
||||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||||
|
|
||||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||||
|
|
||||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
@@ -75,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Set<SchemaElement> schemaElements = entry.getValue();
|
Set<SchemaElement> schemaElements = entry.getValue();
|
||||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
if (!CollectionUtils.isEmpty(detectViewIds)) {
|
||||||
schemaElements = schemaElements.stream()
|
schemaElements = schemaElements.stream()
|
||||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
for (SchemaElement schemaElement : schemaElements) {
|
for (SchemaElement schemaElement : schemaElements) {
|
||||||
@@ -90,10 +85,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElement> getSchemaElements() {
|
private List<SchemaElement> getSchemaElements(QueryContext queryContext) {
|
||||||
List<SchemaElement> allElements = new ArrayList<>();
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
|
allElements.addAll(queryContext.getSemanticSchema().getDimensions());
|
||||||
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
|
allElements.addAll(queryContext.getSemanticSchema().getMetrics());
|
||||||
return allElements;
|
return allElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,7 +96,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||||
|
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
|
||||||
|
|
||||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||||
|
|
||||||
@@ -1,18 +1,19 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import java.util.List;
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* A mapper that recognizes schema elements with vector embedding.
|
* A mapper that recognizes schema elements with vector embedding.
|
||||||
@@ -23,8 +24,9 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
//1. query from embedding by queryText
|
//1. query from embedding by queryText
|
||||||
String queryText = queryContext.getRequest().getQueryText();
|
String queryText = queryContext.getQueryText();
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||||
|
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||||
|
|
||||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||||
@@ -34,16 +36,13 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
//2. build SchemaElementMatch by info
|
//2. build SchemaElementMatch by info
|
||||||
for (EmbeddingResult matchResult : matchResults) {
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||||
|
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
if (Objects.isNull(viewId)) {
|
||||||
SchemaElement.class);
|
|
||||||
|
|
||||||
String modelIdStr = matchResult.getMetadata().get("modelId");
|
|
||||||
if (StringUtils.isBlank(modelIdStr)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
long modelId = Long.parseLong(modelIdStr);
|
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
|
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
|
||||||
|
queryContext.getSemanticSchema());
|
||||||
if (schemaElement == null) {
|
if (schemaElement == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -55,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
.detectWord(matchResult.getDetectWord())
|
.detectWord(matchResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
//3. add to mapInfo
|
//3. add to mapInfo
|
||||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,18 +1,16 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -35,7 +33,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
@Autowired
|
||||||
|
private MetaEmbeddingService metaEmbeddingService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||||
@@ -49,7 +48,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||||
|
String detectSegment, int offset) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
|
|
||||||
List<String> queryTextsList = detectSegments.stream()
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
@@ -63,49 +68,29 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
optimizationConfig.getEmbeddingMapperBatch());
|
optimizationConfig.getEmbeddingMapperBatch());
|
||||||
|
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||||
List<String> queryTextsSub) {
|
List<String> queryTextsSub) {
|
||||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||||
Map<String, String> filterCondition = null;
|
|
||||||
// step1. build query params
|
// step1. build query params
|
||||||
// if only one modelId, add to filterCondition
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
|
||||||
filterCondition = new HashMap<>();
|
|
||||||
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
|
||||||
.queryTextsList(queryTextsSub)
|
|
||||||
.filterCondition(filterCondition)
|
|
||||||
.queryEmbeddings(null)
|
|
||||||
.build();
|
|
||||||
// step2. retrieveQuery by detectSegment
|
// step2. retrieveQuery by detectSegment
|
||||||
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// step3. build EmbeddingResults. filter by modelId
|
// step3. build EmbeddingResults
|
||||||
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||||
.map(retrieveQueryResult -> {
|
.map(retrieveQueryResult -> {
|
||||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||||
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
|
||||||
retrievals.removeIf(retrieval -> {
|
|
||||||
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
|
|
||||||
if (StringUtils.isBlank(modelIdStr)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return retrieveQueryResult;
|
return retrieveQueryResult;
|
||||||
})
|
})
|
||||||
@@ -116,6 +101,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||||
embeddingResult.setName(retrieval.getQuery());
|
embeddingResult.setName(retrieval.getQuery());
|
||||||
|
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||||
|
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
|
||||||
|
embeddingResult.setMetadata(convertedMap);
|
||||||
return embeddingResult;
|
return embeddingResult;
|
||||||
}))
|
}))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
@@ -129,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
selectResultInOneRound(results, oneRoundResults);
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
|
||||||
Integer startIndex, Integer index, int offset) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -25,12 +24,12 @@ public class EntityMapper extends BaseMapper {
|
|||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
|
||||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
|
||||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElement entity = getEntity(modelId);
|
SchemaElement entity = getEntity(viewId, queryContext);
|
||||||
if (entity == null || entity.getId() == null) {
|
if (entity == null || entity.getId() == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -66,9 +65,9 @@ public class EntityMapper extends BaseMapper {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
private SchemaElement getEntity(Long modelId) {
|
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
|
||||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
|
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||||
return modelSchema.getEntity();
|
return modelSchema.getEntity();
|
||||||
}
|
}
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.LinkedHashSet;
|
import java.util.LinkedHashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -35,18 +34,20 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private KnowledgeService knowledgeService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
Set<Long> detectModelIds) {
|
Set<Long> detectViewIds) {
|
||||||
QueryReq queryReq = queryContext.getRequest();
|
String text = queryContext.getQueryText();
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
|
||||||
|
|
||||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
|
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
|
||||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
@@ -59,22 +60,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex, Integer index, int offset) {
|
String detectSegment, int offset) {
|
||||||
QueryReq queryReq = queryContext.getRequest();
|
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
Integer agentId = queryReq.getAgentId();
|
|
||||||
String detectSegment = text.substring(startIndex, index);
|
|
||||||
|
|
||||||
// step1. pre search
|
// step1. pre search
|
||||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||||
agentId,
|
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
// step2. suffix search
|
// step2. suffix search
|
||||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||||
oneDetectionMaxSize, agentId, detectModelIds).stream()
|
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
|
|
||||||
@@ -1,24 +1,26 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
import org.springframework.util.CollectionUtils;
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* A mapper that recognizes schema elements with keyword.
|
* A mapper that recognizes schema elements with keyword.
|
||||||
@@ -29,13 +31,14 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
String queryText = queryContext.getRequest().getQueryText();
|
String queryText = queryContext.getQueryText();
|
||||||
//1.hanlpDict Match
|
//1.hanlpDict Match
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||||
|
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||||
|
|
||||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
||||||
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext.getMapInfo(), terms);
|
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext, terms);
|
||||||
|
|
||||||
//2.database Match
|
//2.database Match
|
||||||
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||||
@@ -44,8 +47,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
|
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, SchemaMapInfo schemaMap,
|
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
||||||
List<Term> terms) {
|
List<S2Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
if (CollectionUtils.isEmpty(mapResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -56,8 +59,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
|
|
||||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||||
for (String nature : hanlpMapResult.getNatures()) {
|
for (String nature : hanlpMapResult.getNatures()) {
|
||||||
Long modelId = NatureHelper.getModelId(nature);
|
Long viewId = NatureHelper.getViewId(nature);
|
||||||
if (Objects.isNull(modelId)) {
|
if (Objects.isNull(viewId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||||
@@ -65,7 +68,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Long elementID = NatureHelper.getElementID(nature);
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
|
SchemaElement element = getSchemaElement(viewId, elementType,
|
||||||
|
elementID, queryContext.getSemanticSchema());
|
||||||
if (element == null) {
|
if (element == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -81,7 +85,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.detectWord(hanlpMapResult.getDetectWord())
|
.detectWord(hanlpMapResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -102,12 +106,12 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||||
.build();
|
.build();
|
||||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
|
||||||
if (CollectionUtils.isEmpty(elements)) {
|
if (CollectionUtils.isEmpty(elements)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import lombok.Data;
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -14,10 +17,6 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Service
|
@Service
|
||||||
@@ -37,8 +36,8 @@ public class MapperHelper {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Integer getStepOffset(List<Term> termList, Integer index) {
|
public Integer getStepOffset(List<S2Term> termList, Integer index) {
|
||||||
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
|
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset))
|
||||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||||
|
|
||||||
for (int j = 0; j < termList.size() - 1; j++) {
|
for (int j = 0; j < termList.size() - 1; j++) {
|
||||||
@@ -63,7 +62,7 @@ public class MapperHelper {
|
|||||||
*/
|
*/
|
||||||
public boolean existDimensionValues(List<String> natures) {
|
public boolean existDimensionValues(List<String> natures) {
|
||||||
for (String nature : natures) {
|
for (String nature : natures) {
|
||||||
if (NatureHelper.isDimensionValueModelId(nature)) {
|
if (NatureHelper.isDimensionValueViewId(nature)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -83,34 +82,33 @@ public class MapperHelper {
|
|||||||
detectSegment.length());
|
detectSegment.length());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Set<Long> getModelIds(QueryReq request) {
|
public Set<Long> getViewIds(Long viewId, Agent agent) {
|
||||||
|
|
||||||
Long modelId = request.getModelId();
|
Set<Long> detectViewIds = new HashSet<>();
|
||||||
|
if (Objects.nonNull(agent)) {
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
detectViewIds = agent.getViewIds(null);
|
||||||
|
}
|
||||||
Set<Long> detectModelIds = agentService.getModelIds(request.getAgentId(), null);
|
|
||||||
//contains all
|
//contains all
|
||||||
if (agentService.containsAllModel(detectModelIds)) {
|
if (Agent.containsAllModel(detectViewIds)) {
|
||||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||||
Set<Long> result = new HashSet<>();
|
Set<Long> result = new HashSet<>();
|
||||||
result.add(modelId);
|
result.add(viewId);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Objects.nonNull(detectModelIds)) {
|
if (Objects.nonNull(detectViewIds)) {
|
||||||
detectModelIds = detectModelIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Objects.nonNull(modelId) && modelId > 0 && Objects.nonNull(detectModelIds)) {
|
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
|
||||||
if (detectModelIds.contains(modelId)) {
|
if (detectViewIds.contains(viewId)) {
|
||||||
Set<Long> result = new HashSet<>();
|
Set<Long> result = new HashSet<>();
|
||||||
result.add(modelId);
|
result.add(viewId);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return detectModelIds;
|
return detectViewIds;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -12,6 +13,6 @@ import java.util.Set;
|
|||||||
*/
|
*/
|
||||||
public interface MatchStrategy<T> {
|
public interface MatchStrategy<T> {
|
||||||
|
|
||||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);
|
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
@@ -10,10 +10,10 @@ import lombok.ToString;
|
|||||||
public class ModelWithSemanticType implements Serializable {
|
public class ModelWithSemanticType implements Serializable {
|
||||||
|
|
||||||
private Long model;
|
private Long model;
|
||||||
private SchemaElementType semanticType;
|
private SchemaElementType schemaElementType;
|
||||||
|
|
||||||
public ModelWithSemanticType(Long model, SchemaElementType semanticType) {
|
public ModelWithSemanticType(Long model, SchemaElementType schemaElementType) {
|
||||||
this.model = model;
|
this.model = model;
|
||||||
this.semanticType = semanticType;
|
this.schemaElementType = schemaElementType;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,46 +1,45 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class QueryFilterMapper implements SchemaMapper {
|
public class QueryFilterMapper implements SchemaMapper {
|
||||||
|
|
||||||
private double similarity = 1.0;
|
private double similarity = 1.0;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(QueryContext queryContext) {
|
public void map(QueryContext queryContext) {
|
||||||
QueryReq queryReq = queryContext.getRequest();
|
Long viewId = queryContext.getViewId();
|
||||||
Long modelId = queryReq.getModelId();
|
if (viewId == null || viewId <= 0) {
|
||||||
if (modelId == null || modelId <= 0) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
clearOtherSchemaElementMatch(modelId, schemaMapInfo);
|
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
|
||||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
|
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
schemaMapInfo.setMatchedElements(modelId, schemaElementMatches);
|
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||||
}
|
}
|
||||||
addValueSchemaElementMatch(queryContext, schemaElementMatches, queryReq.getQueryFilters());
|
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getModelElementMatches().entrySet()) {
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||||
if (!entry.getKey().equals(modelId)) {
|
if (!entry.getKey().equals(modelId)) {
|
||||||
entry.getValue().clear();
|
entry.getValue().clear();
|
||||||
}
|
}
|
||||||
@@ -48,12 +47,12 @@ public class QueryFilterMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
|
private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
|
||||||
List<SchemaElementMatch> candidateElementMatches,
|
List<SchemaElementMatch> candidateElementMatches) {
|
||||||
QueryFilters queryFilter) {
|
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||||
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
|
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
return candidateElementMatches;
|
return candidateElementMatches;
|
||||||
}
|
}
|
||||||
for (QueryFilter filter : queryFilter.getFilters()) {
|
for (QueryFilter filter : queryFilters.getFilters()) {
|
||||||
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -62,7 +61,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
|||||||
.name(String.valueOf(filter.getValue()))
|
.name(String.valueOf(filter.getValue()))
|
||||||
.type(SchemaElementType.VALUE)
|
.type(SchemaElementType.VALUE)
|
||||||
.bizName(filter.getBizName())
|
.bizName(filter.getBizName())
|
||||||
.model(queryContext.getRequest().getModelId())
|
.view(queryContext.getViewId())
|
||||||
.build();
|
.build();
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
@@ -77,7 +76,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
||||||
List<SchemaElementMatch> schemaElementMatches) {
|
List<SchemaElementMatch> schemaElementMatches) {
|
||||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
|
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.component;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.core.knowledge.SearchService;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -15,6 +15,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -26,11 +27,13 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
|
|
||||||
private static final int SEARCH_SIZE = 3;
|
private static final int SEARCH_SIZE = 3;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private KnowledgeService knowledgeService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||||
Set<Long> detectModelIds) {
|
Set<Long> detectViewIds) {
|
||||||
QueryReq queryReq = queryContext.getRequest();
|
String text = queryContext.getQueryText();
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||||
|
|
||||||
List<Integer> detectIndexList = Lists.newArrayList();
|
List<Integer> detectIndexList = Lists.newArrayList();
|
||||||
@@ -53,10 +56,10 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
String detectSegment = text.substring(detectIndex);
|
String detectSegment = text.substring(detectIndex);
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||||
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||||
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
SearchService.SEARCH_SIZE, detectViewIds);
|
||||||
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||||
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
detectSegment, SEARCH_SIZE, detectViewIds);
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
// remove entity name where search
|
// remove entity name where search
|
||||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||||
@@ -90,9 +93,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex,
|
String detectSegment, int offset) {
|
||||||
Integer i, int offset) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,22 +1,24 @@
|
|||||||
package com.tencent.supersonic.chat.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionPromptGenerator;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* LLMProxy based on langchain4j Java version.
|
* LLMProxy based on langchain4j Java version.
|
||||||
*/
|
*/
|
||||||
@@ -24,6 +26,8 @@ import org.springframework.stereotype.Component;
|
|||||||
@Component
|
@Component
|
||||||
public class JavaLLMProxy implements LLMProxy {
|
public class JavaLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isSkip(QueryContext queryContext) {
|
public boolean isSkip(QueryContext queryContext) {
|
||||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
@@ -34,17 +38,14 @@ public class JavaLLMProxy implements LLMProxy {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||||
|
|
||||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||||
String modelName = llmReq.getSchema().getModelName();
|
String modelName = llmReq.getSchema().getViewName();
|
||||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
LLMResp result = sqlGeneration.generation(llmReq, viewId);
|
||||||
|
|
||||||
LLMResp result = new LLMResp();
|
|
||||||
result.setQuery(llmReq.getQueryText());
|
result.setQuery(llmReq.getQueryText());
|
||||||
result.setModelName(modelName);
|
result.setModelName(modelName);
|
||||||
result.setSqlWeight(sqlWeight);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,14 +54,13 @@ public class JavaLLMProxy implements LLMProxy {
|
|||||||
|
|
||||||
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
|
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
|
||||||
|
|
||||||
|
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||||
functionReq.getPluginConfigs());
|
functionReq.getPluginConfigs());
|
||||||
|
keyPipelineLog.info("functionCallPrompt:{}", functionCallPrompt);
|
||||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
String response = chatLanguageModel.generate(functionCallPrompt);
|
||||||
|
keyPipelineLog.info("functionCall response:{}", response);
|
||||||
String functionSelect = chatLanguageModel.generate(functionCallPrompt);
|
return OutputFormat.functionCallParse(response);
|
||||||
|
|
||||||
return OutputFormat.functionCallParse(functionSelect);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LLMProxy encapsulates functions performed by LLMs so that multiple
|
||||||
|
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
|
||||||
|
* could be used.
|
||||||
|
*/
|
||||||
|
public interface LLMProxy {
|
||||||
|
|
||||||
|
boolean isSkip(QueryContext queryContext);
|
||||||
|
|
||||||
|
LLMResp query2sql(LLMReq llmReq, Long viewId);
|
||||||
|
|
||||||
|
FunctionResp requestFunction(FunctionReq functionReq);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,19 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSON;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import java.net.URI;
|
|
||||||
import java.net.URL;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.http.HttpEntity;
|
import org.springframework.http.HttpEntity;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
@@ -23,6 +25,10 @@ import org.springframework.stereotype.Component;
|
|||||||
import org.springframework.web.client.RestTemplate;
|
import org.springframework.web.client.RestTemplate;
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PythonLLMProxy sends requests to LangChain-based python service.
|
* PythonLLMProxy sends requests to LangChain-based python service.
|
||||||
*/
|
*/
|
||||||
@@ -30,6 +36,8 @@ import org.springframework.web.util.UriComponentsBuilder;
|
|||||||
@Component
|
@Component
|
||||||
public class PythonLLMProxy implements LLMProxy {
|
public class PythonLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isSkip(QueryContext queryContext) {
|
public boolean isSkip(QueryContext queryContext) {
|
||||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
@@ -40,10 +48,10 @@ public class PythonLLMProxy implements LLMProxy {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||||
|
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
|
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
|
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
try {
|
try {
|
||||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
|
|
||||||
@@ -55,9 +63,15 @@ public class PythonLLMProxy implements LLMProxy {
|
|||||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||||
LLMResp.class);
|
LLMResp.class);
|
||||||
|
|
||||||
|
LLMResp llmResp = responseEntity.getBody();
|
||||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||||
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
System.currentTimeMillis() - startTime, url, entity, llmResp);
|
||||||
return responseEntity.getBody();
|
keyPipelineLog.info("LLMResp:{}", llmResp);
|
||||||
|
|
||||||
|
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||||
|
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
|
||||||
|
}
|
||||||
|
return llmResp;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("requestLLM error", e);
|
log.error("requestLLM error", e);
|
||||||
}
|
}
|
||||||
@@ -75,10 +89,12 @@ public class PythonLLMProxy implements LLMProxy {
|
|||||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||||
try {
|
try {
|
||||||
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
||||||
|
keyPipelineLog.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
||||||
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||||
FunctionResp.class);
|
FunctionResp.class);
|
||||||
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
|
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
|
||||||
System.currentTimeMillis() - startTime);
|
System.currentTimeMillis() - startTime);
|
||||||
|
keyPipelineLog.info("response:{}", responseEntity.getBody());
|
||||||
return responseEntity.getBody();
|
return responseEntity.getBody();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("requestFunction error", e);
|
log.error("requestFunction error", e);
|
||||||
@@ -1,21 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||||
|
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -35,42 +32,48 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
|
||||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||||
User user = queryContext.getRequest().getUser();
|
User user = queryContext.getUser();
|
||||||
|
|
||||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||||
// 1.init S2SQL
|
// 1.init S2SQL
|
||||||
semanticQuery.initS2Sql(user);
|
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
|
||||||
// 2.set queryType
|
// 2.set queryType
|
||||||
QueryType queryType = getQueryType(semanticQuery);
|
QueryType queryType = getQueryType(queryContext, semanticQuery);
|
||||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryType getQueryType(SemanticQuery semanticQuery) {
|
private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) {
|
||||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
//1. entity queryType
|
//1. entity queryType
|
||||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
Long viewId = parseInfo.getViewId();
|
||||||
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||||
//If all the fields in the SELECT statement are of tag type.
|
//If all the fields in the SELECT statement are of tag type.
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
.collect(Collectors.toList());
|
||||||
if (CollectionUtils.isNotEmpty(selectFields)) {
|
|
||||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||||
|
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
|
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||||
|
return QueryType.ID;
|
||||||
|
}
|
||||||
|
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||||
return QueryType.TAG;
|
return QueryType.TAG;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//2. metric queryType
|
//2. metric queryType
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||||
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
|
||||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
package com.tencent.supersonic.chat.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ public class SatisfactionChecker {
|
|||||||
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.api.component;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A semantic parser understands user queries and extracts semantic information.
|
* A semantic parser understands user queries and extracts semantic information.
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin;
|
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||||
|
|
||||||
public enum ParseMode {
|
public enum ParseMode {
|
||||||
|
|
||||||
@@ -1,27 +1,25 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin;
|
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||||
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -36,7 +34,7 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||||
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
if (queryContext.getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
||||||
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
|
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -57,16 +55,14 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
|
|
||||||
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||||
Plugin plugin = pluginRecallResult.getPlugin();
|
Plugin plugin = pluginRecallResult.getPlugin();
|
||||||
Set<Long> modelIds = pluginRecallResult.getModelIds();
|
Set<Long> viewIds = pluginRecallResult.getViewIds();
|
||||||
if (plugin.isContainsAllModel()) {
|
if (plugin.isContainsAllModel()) {
|
||||||
modelIds = Sets.newHashSet(-1L);
|
viewIds = Sets.newHashSet(-1L);
|
||||||
}
|
}
|
||||||
for (Long modelId : modelIds) {
|
for (Long viewId : viewIds) {
|
||||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
|
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
|
||||||
queryContext.getRequest(),
|
queryContext, pluginRecallResult.getDistance());
|
||||||
queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
|
|
||||||
pluginRecallResult.getDistance());
|
|
||||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||||
pluginQuery.setParseInfo(semanticParseInfo);
|
pluginQuery.setParseInfo(semanticParseInfo);
|
||||||
@@ -75,25 +71,28 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
|
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
|
||||||
List<SchemaElementMatch> schemaElementMatches, double distance) {
|
QueryContext queryContext, double distance) {
|
||||||
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
|
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||||
modelId = plugin.getModelList().get(0);
|
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||||
|
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
|
||||||
|
viewId = plugin.getViewList().get(0);
|
||||||
}
|
}
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||||
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
|
||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||||
pluginParseResult.setPlugin(plugin);
|
pluginParseResult.setPlugin(plugin);
|
||||||
pluginParseResult.setRequest(queryReq);
|
pluginParseResult.setQueryFilters(queryFilters);
|
||||||
pluginParseResult.setDistance(distance);
|
pluginParseResult.setDistance(distance);
|
||||||
|
pluginParseResult.setQueryText(queryContext.getQueryText());
|
||||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||||
properties.put("type", "plugin");
|
properties.put("type", "plugin");
|
||||||
properties.put("name", plugin.getName());
|
properties.put("name", plugin.getName());
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
@@ -42,7 +42,7 @@ public class EmbeddingRecallParser extends PluginParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||||
String text = queryContext.getRequest().getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||||
return null;
|
return null;
|
||||||
@@ -57,15 +57,15 @@ public class EmbeddingRecallParser extends PluginParser {
|
|||||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||||
log.info("embedding plugin resolve: {}", pair);
|
log.info("embedding plugin resolve: {}", pair);
|
||||||
if (pair.getLeft()) {
|
if (pair.getLeft()) {
|
||||||
Set<Long> modelList = pair.getRight();
|
Set<Long> viewList = pair.getRight();
|
||||||
if (CollectionUtils.isEmpty(modelList)) {
|
if (CollectionUtils.isEmpty(viewList)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||||
double distance = embeddingRetrieval.getDistance();
|
double distance = embeddingRetrieval.getDistance();
|
||||||
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
|
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||||
return PluginRecallResult.builder()
|
return PluginRecallResult.builder()
|
||||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -1,16 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
|
||||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.chat.service.PluginService;
|
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -20,7 +19,6 @@ import org.springframework.util.CollectionUtils;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -36,7 +34,7 @@ public class FunctionCallParser extends PluginParser {
|
|||||||
String functionUrl = functionCallConfig.getUrl();
|
String functionUrl = functionCallConfig.getUrl();
|
||||||
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||||
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
||||||
queryContext.getRequest().getQueryText());
|
queryContext.getQueryText());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
List<Plugin> plugins = getPluginList(queryContext);
|
List<Plugin> plugins = getPluginList(queryContext);
|
||||||
@@ -45,35 +43,33 @@ public class FunctionCallParser extends PluginParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
|
||||||
FunctionResp functionResp = functionCall(queryContext);
|
FunctionResp functionResp = functionCall(queryContext);
|
||||||
if (skipFunction(functionResp)) {
|
if (skipFunction(functionResp)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
log.info("requestFunction result:{}", functionResp.getToolSelection());
|
log.info("requestFunction result:{}", functionResp.getToolSelection());
|
||||||
String toolSelection = functionResp.getToolSelection();
|
String toolSelection = functionResp.getToolSelection();
|
||||||
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
|
Plugin plugin = queryContext.getNameToPlugin().get(toolSelection);
|
||||||
if (!pluginOptional.isPresent()) {
|
if (Objects.isNull(plugin)) {
|
||||||
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
|
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
Plugin plugin = pluginOptional.get();
|
|
||||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||||
if (pluginResolveResult.getLeft()) {
|
if (pluginResolveResult.getLeft()) {
|
||||||
Set<Long> modelList = pluginResolveResult.getRight();
|
Set<Long> viewList = pluginResolveResult.getRight();
|
||||||
if (CollectionUtils.isEmpty(modelList)) {
|
if (CollectionUtils.isEmpty(viewList)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
double score = queryContext.getRequest().getQueryText().length();
|
double score = queryContext.getQueryText().length();
|
||||||
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
|
return PluginRecallResult.builder().plugin(plugin).viewIds(viewList).score(score).build();
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FunctionResp functionCall(QueryContext queryContext) {
|
public FunctionResp functionCall(QueryContext queryContext) {
|
||||||
List<PluginParseConfig> pluginToFunctionCall =
|
List<PluginParseConfig> pluginToFunctionCall =
|
||||||
getPluginToFunctionCall(queryContext.getRequest().getModelId(), queryContext);
|
getPluginToFunctionCall(queryContext.getViewId(), queryContext);
|
||||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||||
log.info("function call parser, plugin is empty, skip");
|
log.info("function call parser, plugin is empty, skip");
|
||||||
return null;
|
return null;
|
||||||
@@ -83,7 +79,7 @@ public class FunctionCallParser extends PluginParser {
|
|||||||
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
|
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
|
||||||
} else {
|
} else {
|
||||||
FunctionReq functionReq = FunctionReq.builder()
|
FunctionReq functionReq = FunctionReq.builder()
|
||||||
.queryText(queryContext.getRequest().getQueryText())
|
.queryText(queryContext.getQueryText())
|
||||||
.pluginConfigs(pluginToFunctionCall).build();
|
.pluginConfigs(pluginToFunctionCall).build();
|
||||||
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
|
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.InputFormat;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.InputFormat;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user