(improvement)(chat) Split chat into three modules: server, api, and core. (#594)

lexluo09
2024-01-04 16:56:49 +08:00
committed by GitHub
parent 0858c13365
commit 023e84c420
337 changed files with 2407 additions and 2715 deletions

View File

@@ -0,0 +1,334 @@
package com.hankcs.hanlp.collection.trie.bintrie;
import com.hankcs.hanlp.corpus.io.ByteArray;
import com.tencent.supersonic.chat.core.knowledge.LoadRemoveService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.AbstractMap;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public abstract class BaseNode<V> implements Comparable<BaseNode> {
/**
* Status array, cached for fast access when reading
*/
static final Status[] ARRAY_STATUS = Status.values();
private static final Logger logger = LoggerFactory.getLogger(BaseNode.class);
/**
* Child nodes
*/
protected BaseNode[] child;
/**
* Node status
*/
protected Status status;
/**
* The character this node represents
*/
protected char c;
/**
* The value this node represents
*/
protected V value;
protected String prefix = null;
public BaseNode<V> transition(String path, int begin) {
BaseNode<V> cur = this;
for (int i = begin; i < path.length(); ++i) {
cur = cur.getChild(path.charAt(i));
if (cur == null || cur.status == Status.UNDEFINED_0) {
return null;
}
}
return cur;
}
public BaseNode<V> transition(char[] path, int begin) {
BaseNode<V> cur = this;
for (int i = begin; i < path.length; ++i) {
cur = cur.getChild(path[i]);
if (cur == null || cur.status == Status.UNDEFINED_0) {
return null;
}
}
return cur;
}
/**
* Transition by a single character
*
* @param path the character to transition on
* @return the target node, or null if no valid transition exists
*/
public BaseNode<V> transition(char path) {
BaseNode<V> cur = this;
cur = cur.getChild(path);
if (cur == null || cur.status == Status.UNDEFINED_0) {
return null;
}
return cur;
}
/**
* Add a child node
*
* @return true if a new node was added, false if an existing node was modified
*/
protected abstract boolean addChild(BaseNode node);
/**
* Whether this node has a child for the given character
*
* @param c the child's char
* @return whether the child exists
*/
protected boolean hasChild(char c) {
return getChild(c) != null;
}
protected char getChar() {
return c;
}
/**
* Get the child node
*
* @param c the child's char
* @return the child node
*/
public abstract BaseNode getChild(char c);
/**
* Get the value held by this node
*
* @return the value
*/
public final V getValue() {
return value;
}
/**
* Set the value held by this node
*
* @param value the value
*/
public final void setValue(V value) {
this.value = value;
}
@Override
public int compareTo(BaseNode other) {
return compareTo(other.getChar());
}
/**
* Overload: compare this node against a character
*
* @param other the character to compare with
* @return the comparison result
*/
public int compareTo(char other) {
if (this.c > other) {
return 1;
}
if (this.c < other) {
return -1;
}
return 0;
}
/**
* Get the word-formation status of this node
*
* @return the status
*/
public Status getStatus() {
return status;
}
protected void walk(StringBuilder sb, Set<Map.Entry<String, V>> entrySet) {
sb.append(c);
if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) {
entrySet.add(new TrieEntry(sb.toString(), value));
}
if (child == null) {
return;
}
for (BaseNode node : child) {
if (node == null) {
continue;
}
node.walk(new StringBuilder(sb.toString()), entrySet);
}
}
protected void walkToSave(DataOutputStream out) throws IOException {
out.writeChar(c);
out.writeInt(status.ordinal());
int childSize = 0;
if (child != null) {
childSize = child.length;
}
out.writeInt(childSize);
if (child == null) {
return;
}
for (BaseNode node : child) {
node.walkToSave(out);
}
}
protected void walkToSave(ObjectOutput out) throws IOException {
out.writeChar(c);
out.writeInt(status.ordinal());
if (status == Status.WORD_END_3 || status == Status.WORD_MIDDLE_2) {
out.writeObject(value);
}
int childSize = 0;
if (child != null) {
childSize = child.length;
}
out.writeInt(childSize);
if (child == null) {
return;
}
for (BaseNode node : child) {
node.walkToSave(out);
}
}
protected void walkToLoad(ByteArray byteArray, _ValueArray<V> valueArray) {
c = byteArray.nextChar();
status = ARRAY_STATUS[byteArray.nextInt()];
if (status == Status.WORD_END_3 || status == Status.WORD_MIDDLE_2) {
value = valueArray.nextValue();
}
int childSize = byteArray.nextInt();
child = new BaseNode[childSize];
for (int i = 0; i < childSize; ++i) {
child[i] = new Node<V>();
child[i].walkToLoad(byteArray, valueArray);
}
}
protected void walkToLoad(ObjectInput byteArray) throws IOException, ClassNotFoundException {
c = byteArray.readChar();
status = ARRAY_STATUS[byteArray.readInt()];
if (status == Status.WORD_END_3 || status == Status.WORD_MIDDLE_2) {
value = (V) byteArray.readObject();
}
int childSize = byteArray.readInt();
child = new BaseNode[childSize];
for (int i = 0; i < childSize; ++i) {
child[i] = new Node<V>();
child[i].walkToLoad(byteArray);
}
}
public enum Status {
/**
* Unspecified; used for deleting entries
*/
UNDEFINED_0,
/**
* Not the end of a word
*/
NOT_WORD_1,
/**
* End of a word, and the path may continue further
*/
WORD_MIDDLE_2,
/**
* End of a word, with no continuation
*/
WORD_END_3,
}
public class TrieEntry extends AbstractMap.SimpleEntry<String, V> implements Comparable<TrieEntry> {
public TrieEntry(String key, V value) {
super(key, value);
}
@Override
public int compareTo(TrieEntry o) {
return getKey().compareTo(String.valueOf(o.getKey()));
}
}
@Override
public String toString() {
return "BaseNode{"
+ "child="
+ Arrays.toString(child)
+ ", status="
+ status
+ ", c="
+ c
+ ", value="
+ value
+ ", prefix='"
+ prefix
+ '\''
+ '}';
}
public void walkNode(Set<Map.Entry<String, V>> entrySet, Integer agentId, Set<Long> detectModelIds) {
if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) {
LoadRemoveService loadRemoveService = ContextUtils.getBean(LoadRemoveService.class);
logger.debug("agentId:{},detectModelIds:{},before:{}", agentId, detectModelIds, value.toString());
List natures = loadRemoveService.removeNatures((List) value, agentId, detectModelIds);
String name = this.prefix != null ? this.prefix + c : "" + c;
logger.debug("name:{},after:{},natures:{}", name, (List) value, natures);
entrySet.add(new TrieEntry(name, (V) natures));
}
}
/**
 * Breadth-first walk that stops once the entry set reaches the given limit
 *
 * @param sb prefix accumulated so far
 * @param entrySet collector for matched entries
 * @param limit maximum number of entries to collect
 * @param agentId agent to filter natures by
 * @param detectModelIds model ids to filter natures by
 */
public void walkLimit(StringBuilder sb, Set<Map.Entry<String, V>> entrySet, int limit, Integer agentId,
Set<Long> detectModelIds) {
Queue<BaseNode> queue = new ArrayDeque<>();
this.prefix = sb.toString();
queue.add(this);
while (!queue.isEmpty()) {
if (entrySet.size() >= limit) {
break;
}
BaseNode root = queue.poll();
if (root == null) {
continue;
}
root.walkNode(entrySet, agentId, detectModelIds);
if (root.child == null) {
continue;
}
String prefix = root.prefix + root.c;
for (BaseNode node : root.child) {
if (Objects.nonNull(node)) {
node.prefix = prefix;
queue.add(node);
}
}
}
}
}
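
A minimal usage sketch of the node API above. It assumes the companion BinTrie<V> implementation from the same HanLP package is available (the Node<V> referenced in walkToLoad is its array-backed node type); the words and values are illustrative:

// Sketch: trie lookups via transition() and the Status enum
BinTrie<String> trie = new BinTrie<>();
trie.put("北京", "city");           // a word that is also a prefix of a longer word
trie.put("北京大学", "university");

BaseNode<String> node = trie.transition("北京", 0);
// A word ends here but longer words continue through this node:
assert node.getStatus() == BaseNode.Status.WORD_MIDDLE_2;

node = node.transition("大学".toCharArray(), 0);
// Terminal node: a word ends here with no continuation.
assert node.getStatus() == BaseNode.Status.WORD_END_3;
assert "university".equals(node.getValue());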

View File

@@ -0,0 +1,393 @@
package com.hankcs.hanlp.dictionary;
import static com.hankcs.hanlp.utility.Predefine.logger;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
import com.hankcs.hanlp.corpus.io.ByteArray;
import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.utility.Predefine;
import com.hankcs.hanlp.utility.TextUtility;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.util.Collection;
import java.util.TreeMap;
/**
* Core dictionary implemented with a DoubleArrayTrie
*/
public class CoreDictionary {
public static DoubleArrayTrie<Attribute> trie = new DoubleArrayTrie<Attribute>();
public static final String PATH = HanLP.Config.CoreDictionaryPath;
// Load the dictionary automatically at class initialization
static {
long start = System.currentTimeMillis();
if (!load(PATH)) {
throw new IllegalArgumentException("Failed to load core dictionary " + PATH);
} else {
logger.info(PATH + " loaded successfully: " + trie.size() + " entries, took " + (System.currentTimeMillis() - start) + "ms");
}
}
// Some special WORD_IDs
public static final int NR_WORD_ID = getWordID(Predefine.TAG_PEOPLE);
public static final int NS_WORD_ID = getWordID(Predefine.TAG_PLACE);
public static final int NT_WORD_ID = getWordID(Predefine.TAG_GROUP);
public static final int T_WORD_ID = getWordID(Predefine.TAG_TIME);
public static final int X_WORD_ID = getWordID(Predefine.TAG_CLUSTER);
public static final int M_WORD_ID = getWordID(Predefine.TAG_NUMBER);
public static final int NX_WORD_ID = getWordID(Predefine.TAG_PROPER);
private static boolean load(String path) {
logger.info("核心词典开始加载:" + path);
if (loadDat(path)) {
return true;
}
TreeMap<String, Attribute> map = new TreeMap<String, Attribute>();
BufferedReader br = null;
try {
br = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(path), "UTF-8"));
String line;
int totalFrequency = 0;
long start = System.currentTimeMillis();
while ((line = br.readLine()) != null) {
String[] param = line.split("\\s");
int natureCount = (param.length - 1) / 2;
Attribute attribute = new Attribute(natureCount);
for (int i = 0; i < natureCount; ++i) {
attribute.nature[i] = Nature.create(param[1 + 2 * i]);
attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]);
attribute.totalFrequency += attribute.frequency[i];
}
map.put(param[0], attribute);
totalFrequency += attribute.totalFrequency;
}
logger.info("Core dictionary read " + map.size() + " entries, total frequency " + totalFrequency
+ ", took " + (System.currentTimeMillis() - start) + "ms");
br.close();
trie.build(map);
logger.info("核心词典加载成功:" + trie.size() + "个词条,下面将写入缓存……");
try {
DataOutputStream out = new DataOutputStream(
new BufferedOutputStream(IOUtil.newOutputStream(path + Predefine.BIN_EXT)));
Collection<Attribute> attributeList = map.values();
out.writeInt(attributeList.size());
for (Attribute attribute : attributeList) {
out.writeInt(attribute.totalFrequency);
out.writeInt(attribute.nature.length);
for (int i = 0; i < attribute.nature.length; ++i) {
out.writeInt(attribute.nature[i].ordinal());
out.writeInt(attribute.frequency[i]);
}
}
trie.save(out);
out.writeInt(totalFrequency);
Predefine.setTotalFrequency(totalFrequency);
out.close();
} catch (Exception e) {
logger.warning("保存失败" + e);
return false;
}
} catch (FileNotFoundException e) {
logger.warning("核心词典" + path + "不存在!" + e);
return false;
} catch (IOException e) {
logger.warning("核心词典" + path + "读取错误!" + e);
return false;
}
return true;
}
/**
* Load the double-array trie from disk
*
* @param path dictionary path
* @return whether loading succeeded
*/
static boolean loadDat(String path) {
try {
ByteArray byteArray = ByteArray.createByteArray(path + Predefine.BIN_EXT);
if (byteArray == null) {
return false;
}
int size = byteArray.nextInt();
Attribute[] attributes = new Attribute[size];
final Nature[] natureIndexArray = Nature.values();
for (int i = 0; i < size; ++i) {
// First int is the total frequency, second is the number of natures
int currentTotalFrequency = byteArray.nextInt();
int length = byteArray.nextInt();
attributes[i] = new Attribute(length);
attributes[i].totalFrequency = currentTotalFrequency;
for (int j = 0; j < length; ++j) {
attributes[i].nature[j] = natureIndexArray[byteArray.nextInt()];
attributes[i].frequency[j] = byteArray.nextInt();
}
}
if (!trie.load(byteArray, attributes)) {
return false;
}
int totalFrequency = 0;
if (byteArray.hasMore()) {
totalFrequency = byteArray.nextInt();
} else {
for (Attribute attribute : attributes) {
totalFrequency += attribute.totalFrequency;
}
}
Predefine.setTotalFrequency(totalFrequency);
} catch (Exception e) {
logger.warning("读取失败,问题发生在" + e);
return false;
}
return true;
}
/**
* Get an entry
*
* @param key the word
* @return its attribute, or null if absent
*/
public static Attribute get(String key) {
return trie.get(key);
}
/**
* Get an entry
*
* @param wordID the word ID
* @return its attribute
*/
public static Attribute get(int wordID) {
return trie.get(wordID);
}
/**
* Get the total frequency of a term
*
* @param term the term
* @return its total frequency, or 0 if absent
*/
public static int getTermFrequency(String term) {
Attribute attribute = get(term);
if (attribute == null) {
return 0;
}
return attribute.totalFrequency;
}
/**
* Whether the dictionary contains the word
*
* @param key the word
* @return whether it is present
*/
public static boolean contains(String key) {
return trie.get(key) != null;
}
/**
* Attributes of a word in the core dictionary
*/
public static class Attribute implements Serializable {
/**
* Nature (part-of-speech) list
*/
public Nature[] nature;
/**
* Frequency for each nature
*/
public int[] frequency;
public int totalFrequency;
public String original = null;
public Attribute(int size) {
nature = new Nature[size];
frequency = new int[size];
}
public Attribute(Nature[] nature, int[] frequency) {
this.nature = nature;
this.frequency = frequency;
}
public Attribute(Nature nature, int frequency) {
this(1);
this.nature[0] = nature;
this.frequency[0] = frequency;
totalFrequency = frequency;
}
public Attribute(Nature[] nature, int[] frequency, int totalFrequency) {
this.nature = nature;
this.frequency = frequency;
this.totalFrequency = totalFrequency;
}
/**
* Construct with a single nature and a default frequency of 1000
*
* @param nature the nature
*/
public Attribute(Nature nature) {
this(nature, 1000);
}
public static Attribute create(String natureWithFrequency) {
try {
String[] param = natureWithFrequency.split(" ");
if (param.length % 2 != 0) {
return new Attribute(Nature.create(natureWithFrequency.trim()), 1); // child lock: tolerate malformed input
}
int natureCount = param.length / 2;
Attribute attribute = new Attribute(natureCount);
for (int i = 0; i < natureCount; ++i) {
attribute.nature[i] = Nature.create(param[2 * i]);
attribute.frequency[i] = Integer.parseInt(param[1 + 2 * i]);
attribute.totalFrequency += attribute.frequency[i];
}
return attribute;
} catch (Exception e) {
logger.warning("使用字符串" + natureWithFrequency + "创建词条属性失败!" + TextUtility.exceptionToString(e));
return null;
}
}
/**
* Load from a byte stream
*
* @param byteArray source bytes
* @param natureIndexArray nature lookup table
* @return the loaded attribute
*/
public static Attribute create(ByteArray byteArray, Nature[] natureIndexArray) {
int currentTotalFrequency = byteArray.nextInt();
int length = byteArray.nextInt();
Attribute attribute = new Attribute(length);
attribute.totalFrequency = currentTotalFrequency;
for (int j = 0; j < length; ++j) {
attribute.nature[j] = natureIndexArray[byteArray.nextInt()];
attribute.frequency[j] = byteArray.nextInt();
}
return attribute;
}
/**
* Get the frequency of a nature
*
* @param nature the nature as a string
* @return its frequency
* @deprecated prefer the Nature overload
*/
public int getNatureFrequency(String nature) {
try {
Nature pos = Nature.create(nature);
return getNatureFrequency(pos);
} catch (IllegalArgumentException e) {
return 0;
}
}
/**
* Get the frequency of a nature
*
* @param nature the nature
* @return its frequency
*/
public int getNatureFrequency(final Nature nature) {
int i = 0;
for (Nature pos : this.nature) {
if (nature == pos) {
return frequency[i];
}
++i;
}
return 0;
}
/**
* Whether the word has the given nature
*
* @param nature the nature
* @return whether it is present
*/
public boolean hasNature(Nature nature) {
return getNatureFrequency(nature) > 0;
}
/**
* Whether any nature starts with the given prefix
*
* @param prefix a nature prefix, e.g. "u" matches ude, uzhe, etc.
* @return whether such a nature exists
*/
public boolean hasNatureStartsWith(String prefix) {
for (Nature n : nature) {
if (n.startsWith(prefix)) {
return true;
}
}
return false;
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
for (int i = 0; i < nature.length; ++i) {
sb.append(nature[i]).append(' ').append(frequency[i]).append(' ');
}
return sb.toString();
}
public void save(DataOutputStream out) throws IOException {
out.writeInt(totalFrequency);
out.writeInt(nature.length);
for (int i = 0; i < nature.length; ++i) {
out.writeInt(nature[i].ordinal());
out.writeInt(frequency[i]);
}
}
}
/**
* Get the ID of a word
*
* @param a the word
* @return its ID, or -1 if absent
*/
public static int getWordID(String a) {
return CoreDictionary.trie.exactMatchSearch(a);
}
/**
* Hot-reload the core dictionary<br>
* In cluster environments or with other IOAdapters, delete the cache file manually
*
* @return whether the reload succeeded
*/
public static boolean reload() {
String path = CoreDictionary.PATH;
IOUtil.deleteFile(path + Predefine.BIN_EXT);
return load(path);
}
}
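
A short sketch of the string format that Attribute.create parses: alternating nature and frequency tokens separated by single spaces, with an odd token count treated as a single nature at frequency 1 (the "child lock" fallback above). Nature.n and Nature.v are standard HanLP tags:

// "n 1000 v 20": noun with frequency 1000, verb with frequency 20
CoreDictionary.Attribute attribute = CoreDictionary.Attribute.create("n 1000 v 20");
// attribute.nature == [n, v], attribute.frequency == [1000, 20], totalFrequency == 1020
assert attribute.getNatureFrequency(Nature.v) == 20;
assert attribute.hasNatureStartsWith("n");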

View File

@@ -0,0 +1,341 @@
package com.hankcs.hanlp.seg;
import com.hankcs.hanlp.algorithm.Viterbi;
import com.hankcs.hanlp.collection.AhoCorasick.AhoCorasickDoubleArrayTrie;
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.dictionary.CoreDictionaryTransformMatrixDictionary;
import com.hankcs.hanlp.dictionary.other.CharType;
import com.hankcs.hanlp.seg.NShort.Path.AtomNode;
import com.hankcs.hanlp.seg.common.Graph;
import com.hankcs.hanlp.seg.common.Term;
import com.hankcs.hanlp.seg.common.Vertex;
import com.hankcs.hanlp.seg.common.WordNet;
import com.hankcs.hanlp.utility.TextUtility;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
public abstract class WordBasedSegment extends Segment {
public WordBasedSegment() {
}
protected static void generateWord(List<Vertex> linkedArray, WordNet wordNetOptimum) {
fixResultByRule(linkedArray);
wordNetOptimum.addAll(linkedArray);
}
protected static void fixResultByRule(List<Vertex> linkedArray) {
mergeContinueNumIntoOne(linkedArray);
changeDelimiterPOS(linkedArray);
splitMiddleSlashFromDigitalWords(linkedArray);
checkDateElements(linkedArray);
}
static void changeDelimiterPOS(List<Vertex> linkedArray) {
Iterator var1 = linkedArray.iterator();
while (true) {
Vertex vertex;
do {
if (!var1.hasNext()) {
return;
}
vertex = (Vertex) var1.next();
} while (!vertex.realWord.equals("——") && !vertex.realWord.equals("—") && !vertex.realWord.equals("-"));
vertex.confirmNature(Nature.w);
}
}
private static void splitMiddleSlashFromDigitalWords(List<Vertex> linkedArray) {
if (linkedArray.size() >= 2) {
ListIterator<Vertex> listIterator = linkedArray.listIterator();
Vertex next = (Vertex) listIterator.next();
for (Vertex current = next; listIterator.hasNext(); current = next) {
next = (Vertex) listIterator.next();
Nature currentNature = current.getNature();
if (currentNature == Nature.nx && (next.hasNature(Nature.q) || next.hasNature(Nature.n))) {
String[] param = current.realWord.split("-", 1);
if (param.length == 2 && TextUtility.isAllNum(param[0]) && TextUtility.isAllNum(param[1])) {
current = current.copy();
current.realWord = param[0];
current.confirmNature(Nature.m);
listIterator.previous();
listIterator.previous();
listIterator.set(current);
listIterator.next();
listIterator.add(Vertex.newPunctuationInstance("-"));
listIterator.add(Vertex.newNumberInstance(param[1]));
}
}
}
}
}
private static void checkDateElements(List<Vertex> linkedArray) {
if (linkedArray.size() >= 2) {
ListIterator<Vertex> listIterator = linkedArray.listIterator();
Vertex next = (Vertex) listIterator.next();
for (Vertex current = next; listIterator.hasNext(); current = next) {
next = (Vertex) listIterator.next();
if (TextUtility.isAllNum(current.realWord) || TextUtility.isAllChineseNum(current.realWord)) {
String nextWord = next.realWord;
if (nextWord.length() == 1 && "月日时分秒".contains(nextWord)
|| nextWord.length() == 2 && nextWord.equals("月份")) {
mergeDate(listIterator, next, current);
} else if (nextWord.equals("年")) {
if (TextUtility.isYearTime(current.realWord)) {
mergeDate(listIterator, next, current);
} else {
current.confirmNature(Nature.m);
}
} else if (current.realWord.endsWith("年")) {
current.confirmNature(Nature.t, true);
} else {
char[] tmpCharArray = current.realWord.toCharArray();
String lastChar = String.valueOf(tmpCharArray[tmpCharArray.length - 1]);
if (!"∶·././".contains(lastChar)) {
current.confirmNature(Nature.m, true);
} else if (current.realWord.length() > 1) {
char last = current.realWord.charAt(current.realWord.length() - 1);
current = Vertex.newNumberInstance(
current.realWord.substring(0, current.realWord.length() - 1));
listIterator.previous();
listIterator.previous();
listIterator.set(current);
listIterator.next();
listIterator.add(Vertex.newPunctuationInstance(String.valueOf(last)));
}
}
}
}
}
}
private static void mergeDate(ListIterator<Vertex> listIterator, Vertex next, Vertex current) {
current = Vertex.newTimeInstance(current.realWord + next.realWord);
listIterator.previous();
listIterator.previous();
listIterator.set(current);
listIterator.next();
listIterator.next();
listIterator.remove();
}
protected static List<Term> convert(List<Vertex> vertexList) {
return convert(vertexList, false);
}
protected static Graph generateBiGraph(WordNet wordNet) {
return wordNet.toGraph();
}
/**
* @deprecated
*/
private static List<AtomNode> atomSegment(String sSentence, int start, int end) {
if (end < start) {
throw new RuntimeException("start=" + start + " < end=" + end);
} else {
List<AtomNode> atomSegment = new ArrayList();
int pCur = 0;
StringBuilder sb = new StringBuilder();
char[] charArray = sSentence.substring(start, end).toCharArray();
int[] charTypeArray = new int[charArray.length];
for (int i = 0; i < charArray.length; ++i) {
char c = charArray[i];
charTypeArray[i] = CharType.get(c);
if (c == '.' && i < charArray.length - 1 && CharType.get(charArray[i + 1]) == 9) {
charTypeArray[i] = 9;
} else if (c == '.' && i < charArray.length - 1 && charArray[i + 1] >= '0' && charArray[i + 1] <= '9') {
charTypeArray[i] = 5;
} else if (charTypeArray[i] == 8) {
charTypeArray[i] = 5;
}
}
while (true) {
while (true) {
while (pCur < charArray.length) {
int nCurType = charTypeArray[pCur];
if (nCurType != 7 && nCurType != 10 && nCurType != 6 && nCurType != 17) {
if (pCur < charArray.length - 1 && (nCurType == 5 || nCurType == 9)) {
sb.delete(0, sb.length());
sb.append(charArray[pCur]);
boolean reachEnd = true;
while (pCur < charArray.length - 1) {
++pCur;
int nNextType = charTypeArray[pCur];
if (nNextType != nCurType) {
reachEnd = false;
break;
}
sb.append(charArray[pCur]);
}
atomSegment.add(new AtomNode(sb.toString(), nCurType));
if (reachEnd) {
++pCur;
}
} else {
atomSegment.add(new AtomNode(charArray[pCur], nCurType));
++pCur;
}
} else {
String single = String.valueOf(charArray[pCur]);
if (single.length() != 0) {
atomSegment.add(new AtomNode(single, nCurType));
}
++pCur;
}
}
return atomSegment;
}
}
}
}
private static void mergeContinueNumIntoOne(List<Vertex> linkedArray) {
if (linkedArray.size() >= 2) {
ListIterator<Vertex> listIterator = linkedArray.listIterator();
Vertex next = (Vertex) listIterator.next();
Vertex current = next;
while (true) {
while (listIterator.hasNext()) {
next = (Vertex) listIterator.next();
if (!TextUtility.isAllNum(current.realWord) && !TextUtility.isAllChineseNum(current.realWord)
|| !TextUtility.isAllNum(next.realWord) && !TextUtility.isAllChineseNum(next.realWord)) {
current = next;
} else {
current = Vertex.newNumberInstance(current.realWord + next.realWord);
listIterator.previous();
listIterator.previous();
listIterator.set(current);
listIterator.next();
listIterator.next();
listIterator.remove();
}
}
return;
}
}
}
protected void generateWordNet(final WordNet wordNetStorage) {
final char[] charArray = wordNetStorage.charArray;
DoubleArrayTrie.Searcher searcher = CoreDictionary.trie.getSearcher(charArray, 0);
while (searcher.next()) {
wordNetStorage.add(searcher.begin + 1, new Vertex(new String(charArray, searcher.begin, searcher.length),
(CoreDictionary.Attribute) searcher.value, searcher.index));
}
if (this.config.forceCustomDictionary) {
this.customDictionary.parseText(charArray, new AhoCorasickDoubleArrayTrie.IHit<CoreDictionary.Attribute>() {
public void hit(int begin, int end, CoreDictionary.Attribute value) {
wordNetStorage.add(begin + 1, new Vertex(new String(charArray, begin, end - begin), value));
}
});
}
LinkedList<Vertex>[] vertexes = wordNetStorage.getVertexes();
int i = 1;
while (true) {
while (i < vertexes.length) {
if (vertexes[i].isEmpty()) {
int j;
for (j = i + 1;
j < vertexes.length - 1 && (vertexes[j].isEmpty() || CharType.get(charArray[j - 1]) == 11);
++j) {
}
wordNetStorage.add(i, quickAtomSegment(charArray, i - 1, j - 1));
i = j;
} else {
i += ((Vertex) vertexes[i].getLast()).realWord.length();
}
}
return;
}
}
protected List<Term> decorateResultForIndexMode(List<Vertex> vertexList, WordNet wordNetAll) {
List<Term> termList = new LinkedList();
int line = 1;
ListIterator<Vertex> listIterator = vertexList.listIterator();
listIterator.next();
int length = vertexList.size() - 2;
for (int i = 0; i < length; ++i) {
Vertex vertex = (Vertex) listIterator.next();
Term termMain = convert(vertex);
//termList.add(termMain);
addTerms(termList, vertex, line - 1);
termMain.offset = line - 1;
if (vertex.realWord.length() > 2) {
label43:
for (int currentLine = line; currentLine < line + vertex.realWord.length(); ++currentLine) {
Iterator iterator = wordNetAll.descendingIterator(currentLine);
while (true) {
Vertex smallVertex;
do {
if (!iterator.hasNext()) {
continue label43;
}
smallVertex = (Vertex) iterator.next();
} while ((termMain.nature != Nature.mq || !smallVertex.hasNature(Nature.q))
&& smallVertex.realWord.length() < this.config.indexMode);
if (smallVertex != vertex
&& currentLine + smallVertex.realWord.length() <= line + vertex.realWord.length()) {
listIterator.add(smallVertex);
//Term termSub = convert(smallVertex);
//termSub.offset = currentLine - 1;
//termList.add(termSub);
addTerms(termList, smallVertex, currentLine - 1);
}
}
}
}
line += vertex.realWord.length();
}
return termList;
}
protected static void speechTagging(List<Vertex> vertexList) {
Viterbi.compute(vertexList, CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary);
}
protected void addTerms(List<Term> terms, Vertex vertex, int offset) {
for (int i = 0; i < vertex.attribute.nature.length; i++) {
Term term = new Term(vertex.realWord, vertex.attribute.nature[i]);
term.setFrequency(vertex.attribute.frequency[i]);
term.offset = offset;
terms.add(term);
}
}
}
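
The number and date merging above relies on a single ListIterator pattern: step back twice, overwrite the previous element with the merged vertex, step forward past it, then remove the consumed neighbor. A minimal standalone sketch of that pattern on plain strings (class and helper names are illustrative, not part of the codebase):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;

public class MergePatternDemo {
    public static void main(String[] args) {
        List<String> tokens = new ArrayList<>(Arrays.asList("2", "0", "2", "4", "年"));
        ListIterator<String> it = tokens.listIterator();
        String current = it.next();
        while (it.hasNext()) {
            String next = it.next();
            if (isDigits(current) && isDigits(next)) {
                current = current + next; // merged token
                it.previous();            // step back over 'next'
                it.previous();            // step back over 'current'
                it.set(current);          // overwrite 'current' with the merge
                it.next();                // move past the merged token
                it.next();                // move onto 'next'
                it.remove();              // drop the consumed neighbor
            } else {
                current = next;
            }
        }
        System.out.println(tokens); // [2024, 年]
    }

    private static boolean isDigits(String s) {
        return s.chars().allMatch(Character::isDigit);
    }
}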

View File

@@ -0,0 +1,69 @@
package com.hankcs.hanlp.seg.common;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.dictionary.CustomDictionary;
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class Term {
public String word;
public Nature nature;
public int offset;
public int frequency = 0;
public Term(String word, Nature nature) {
this.word = word;
this.nature = nature;
}
public Term(String word, Nature nature, int offset) {
this.word = word;
this.nature = nature;
this.offset = offset;
}
public Term(String word, Nature nature, int offset, int frequency) {
this.word = word;
this.nature = nature;
this.offset = offset;
this.frequency = frequency;
}
public int length() {
return this.word.length();
}
public int getFrequency() {
if (frequency > 0) {
return frequency;
}
String wordOri = word.toLowerCase();
CoreDictionary.Attribute attribute = HanlpHelper.getDynamicCustomDictionary().get(wordOri);
if (attribute == null) {
attribute = CoreDictionary.get(wordOri);
if (attribute == null) {
attribute = CustomDictionary.get(wordOri);
}
}
if (attribute != null && nature != null && attribute.hasNature(nature)) {
return attribute.getNatureFrequency(nature);
}
return attribute == null ? 0 : attribute.totalFrequency;
}
public boolean equals(Object obj) {
if (obj instanceof Term) {
Term term = (Term) obj;
if (this.nature == term.nature && this.word.equals(term.word)) {
return true;
}
}
return super.equals(obj);
}
}
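
When a Term is built without an explicit frequency, getFrequency falls back through the dictionaries in the order shown above: the dynamic custom dictionary, then the core dictionary, then the static custom dictionary, preferring the nature-specific count when the term's nature is known. A hedged usage sketch (word and nature are illustrative):

Term term = new Term("销量", Nature.n);
// frequency == 0, so this consults HanlpHelper.getDynamicCustomDictionary(),
// then CoreDictionary, then CustomDictionary:
int frequency = term.getFrequency();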

View File

@@ -1,46 +0,0 @@
package com.tencent.supersonic.chat.config;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import lombok.ToString;
import java.util.List;
@Data
@ToString
public class ChatConfig {
/**
* database auto-increment primary key
*/
private Long id;
private Long modelId;
/**
* the chatDetailConfig about the model
*/
private ChatDetailConfigReq chatDetailConfig;
/**
* the chatAggConfig about the model
*/
private ChatAggConfigReq chatAggConfig;
private List<RecommendedQuestionReq> recommendedQuestions;
/**
* available status
*/
private StatusEnum status;
/**
* about createdBy, createdAt, updatedBy, updatedAt
*/
private RecordInfo recordInfo;
}

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.chat.config;
import lombok.Data;
@Data
public class ChatConfigFilterInternal {
private Long id;
private Long modelId;
private Integer status;
}

View File

@@ -1,15 +1,18 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.Objects;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Data;
import org.springframework.util.CollectionUtils;
@Data
public class Agent extends RecordInfo {
@@ -23,6 +26,7 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private String agentConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
@@ -45,4 +49,27 @@ public class Agent extends RecordInfo {
return enableSearch != null && enableSearch == 1;
}
public boolean containsAllModel(Set<Long> detectModelIds) {
return !CollectionUtils.isEmpty(detectModelIds) && detectModelIds.contains(-1L);
}
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
List<String> tools = this.getTools(agentToolType);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
.collect(Collectors.toList());
}
public Set<Long> getModelIds(AgentToolType agentToolType) {
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
}
return commonAgentTools.stream().map(NL2SQLTool::getModelIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream)
.collect(Collectors.toSet());
}
}
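
A sketch of how the new helpers compose, assuming an agentConfig JSON whose tool entries carry a type matching AgentToolType plus the modelIds consumed by NL2SQLTool; the exact JSON schema and the NL2SQL_LLM constant are assumptions, since the enum is truncated in this diff:

Agent agent = new Agent();
agent.setAgentConfig("{\"tools\":[{\"type\":\"NL2SQL_LLM\",\"modelIds\":[1,2]}]}");
// Parse every tool of the given type into an NL2SQLTool:
List<NL2SQLTool> tools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
// Union of all non-empty modelIds across those tools: [1, 2]
Set<Long> modelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM);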

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
public enum AgentToolType {
NL2SQL_RULE,

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;

View File

@@ -0,0 +1,44 @@
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class DefaultSemanticConfig {
@Value("${semantic.url.prefix:http://localhost:8081}")
private String semanticUrl;
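// Note: the field groups below reuse the same property placeholder keys
// (searchByStruct.path, fetchModelList.path), so only their default values differ.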
@Value("${searchByStruct.path:/api/semantic/query/struct}")
private String searchByStructPath;
@Value("${searchByStruct.path:/api/semantic/query/multiStruct}")
private String searchByMultiStructPath;
@Value("${searchByStruct.path:/api/semantic/query/sql}")
private String searchBySqlPath;
@Value("${searchByStruct.path:/api/semantic/query/queryDimValue}")
private String queryDimValuePath;
@Value("${fetchModelSchemaPath.path:/api/semantic/schema}")
private String fetchModelSchemaPath;
@Value("${fetchModelList.path:/api/semantic/schema/dimension/page}")
private String fetchDimensionPagePath;
@Value("${fetchModelList.path:/api/semantic/schema/metric/page}")
private String fetchMetricPagePath;
@Value("${fetchModelList.path:/api/semantic/schema/domain/list}")
private String fetchDomainListPath;
@Value("${fetchModelList.path:/api/semantic/schema/model/list}")
private String fetchModelListPath;
@Value("${explain.path:/api/semantic/query/explain}")
private String explainPath;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import java.util.List;
import lombok.AllArgsConstructor;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import java.util.List;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.headless.api.response.DimSchemaResp;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import lombok.Data;

View File

@@ -0,0 +1,40 @@
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
import java.io.FileNotFoundException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Data
@Configuration
@Slf4j
public class LocalFileConfig {
@Value("${dict.directory.latest:/data/dictionary/custom}")
private String dictDirectoryLatest;
@Value("${dict.directory.backup:./dict/backup}")
private String dictDirectoryBackup;
public String getDictDirectoryLatest() {
return getResourceDir() + dictDirectoryLatest;
}
public String getDictDirectoryBackup() {
return dictDirectoryBackup;
}
private String getResourceDir() {
String hanlpPropertiesPath = "";
try {
hanlpPropertiesPath = HanlpHelper.getHanlpPropertiesPath();
} catch (FileNotFoundException e) {
log.warn("getResourceDir, e:", e);
}
return hanlpPropertiesPath;
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.service.SysParameterService;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

View File

@@ -1,22 +1,14 @@
package com.tencent.supersonic.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -24,6 +16,10 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
/**
* basic semantic correction functionality, offering common methods and an
@@ -32,23 +28,23 @@ import java.util.stream.Collectors;
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
return;
}
doCorrect(queryReq, semanticParseInfo);
doCorrect(queryContext, semanticParseInfo);
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
} catch (Exception e) {
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
}
}
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Set<Long> modelIds) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
dbAllFields.addAll(semanticSchema.getMetrics());
@@ -101,12 +97,12 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
}
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
List<SchemaElement> metrics = getMetricElements(modelIds);
List<SchemaElement> metrics = getMetricElements(queryContext, modelIds);
Map<String, String> metricToAggregate = metrics.stream()
.map(schemaElement -> {
@@ -131,8 +127,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Set<Long> modelIds) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
return semanticSchema.getMetrics(modelIds);
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import lombok.extern.slf4j.Slf4j;
/**
@@ -11,7 +11,7 @@ import lombok.extern.slf4j.Slf4j;
public class FromCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String modelName = semanticParseInfo.getModel().getName();
String correctSql = SqlParserReplaceHelper
.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);

View File

@@ -1,21 +1,18 @@
package com.tencent.supersonic.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "Group by" section in S2SQL.
@@ -24,19 +21,19 @@ import java.util.stream.Collectors;
public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
addGroupByFields(semanticParseInfo);
addGroupByFields(queryContext, semanticParseInfo);
}
private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
//add alias field name
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
.flatMap(
@@ -77,15 +74,15 @@ public class GroupByCorrector extends BaseSemanticCorrector {
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
addAggregate(semanticParseInfo);
addAggregate(queryContext, semanticParseInfo);
}
private void addAggregate(SemanticParseInfo semanticParseInfo) {
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}
addAggregateToMetric(semanticParseInfo);
addAggregateToMetric(queryContext, semanticParseInfo);
}
}

View File

@@ -1,18 +1,14 @@
package com.tencent.supersonic.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.Set;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@@ -24,20 +20,20 @@ import org.springframework.util.CollectionUtils;
public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
addHaving(semanticParseInfo);
addHaving(queryContext, semanticParseInfo);
//add having expression filed to select
addHavingToSelect(semanticParseInfo);
}
private void addHaving(SemanticParseInfo semanticParseInfo) {
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
@@ -24,7 +24,7 @@ import org.springframework.util.CollectionUtils;
public class SchemaCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
correctAggFunction(semanticParseInfo);
@@ -34,7 +34,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
updateFieldValueByLinkingValue(semanticParseInfo);
correctFieldName(semanticParseInfo);
correctFieldName(queryContext, semanticParseInfo);
}
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
@@ -50,8 +50,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
sqlInfo.setCorrectS2SQL(replaceAlias);
}
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getModel().getModelIds());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
@@ -14,7 +14,7 @@ import org.springframework.util.CollectionUtils;
public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
/**
* A semantic corrector checks the validity of extracted semantic information and
* performs correction and optimization as needed.
*/
public interface SemanticCorrector {
void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
}

View File

@@ -1,20 +1,24 @@
package com.tencent.supersonic.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.sql.llm.S2SqlDateHelper;
import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
@@ -23,13 +27,6 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "Where" section in S2SQL.
*/
@@ -37,19 +34,19 @@ import java.util.stream.Collectors;
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
addDateIfNotExist(semanticParseInfo);
addDateIfNotExist(queryContext, semanticParseInfo);
parserDateDiffFunction(semanticParseInfo);
addQueryFilter(queryReq, semanticParseInfo);
addQueryFilter(queryContext, semanticParseInfo);
updateFieldValueByTechName(semanticParseInfo);
updateFieldValueByTechName(queryContext, semanticParseInfo);
}
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
@@ -72,11 +69,11 @@ public class WhereCorrector extends BaseSemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
String currentDate = S2SqlDateHelper.getReferenceDate(semanticParseInfo.getModelId());
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getModelId());
if (StringUtils.isNotBlank(currentDate)) {
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
correctS2SQL = SqlParserAddHelper.addWhere(
@@ -100,8 +97,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
.collect(Collectors.joining(Constants.AND_UPPER));
}
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);

View File

@@ -0,0 +1,30 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.google.common.base.Objects;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class DatabaseMapResult extends MapResult {
private SchemaElement schemaElement;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
DatabaseMapResult that = (DatabaseMapResult) o;
return Objects.equal(name, that.name) && Objects.equal(schemaElement, that.schemaElement);
}
@Override
public int hashCode() {
return Objects.hashCode(name, schemaElement);
}
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.core.knowledge;
import java.util.List;
import lombok.Data;
@Data
public class DictConfig {
private Long modelId;
private List<DimValueInfo> dimValueInfoList;
}

View File

@@ -0,0 +1,31 @@
package com.tencent.supersonic.chat.core.knowledge;
public enum DictUpdateMode {
OFFLINE_FULL("OFFLINE_FULL"),
OFFLINE_MODEL("OFFLINE_MODEL"),
REALTIME_ADD("REALTIME_ADD"),
REALTIME_DELETE("REALTIME_DELETE"),
NOT_SUPPORT("NOT_SUPPORT");
private String value;
DictUpdateMode(String value) {
this.value = value;
}
public static DictUpdateMode of(String value) {
for (DictUpdateMode item : DictUpdateMode.values()) {
if (item.value.equalsIgnoreCase(value)) {
return item;
}
}
return DictUpdateMode.NOT_SUPPORT;
}
public String getValue() {
return value;
}
}
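
One usage note: DictUpdateMode.of is a case-insensitive, fail-soft parser, so unrecognized inputs map to NOT_SUPPORT instead of throwing:

DictUpdateMode mode = DictUpdateMode.of("offline_full"); // OFFLINE_FULL
DictUpdateMode unknown = DictUpdateMode.of("bogus");     // NOT_SUPPORT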

View File

@@ -0,0 +1,34 @@
package com.tencent.supersonic.chat.core.knowledge;
import java.util.Objects;
import lombok.Data;
import lombok.ToString;
/**
 * A word and its nature (part of speech)
 */
@Data
@ToString
public class DictWord {
private String word;
private String nature;
private String natureWithFrequency;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
DictWord that = (DictWord) o;
return Objects.equals(word, that.word) && Objects.equals(natureWithFrequency, that.natureWithFrequency);
}
@Override
public int hashCode() {
return Objects.hash(word, natureWithFrequency);
}
}

View File

@@ -0,0 +1,38 @@
package com.tencent.supersonic.chat.core.knowledge.dictionary;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* Dictionary Attribute Util
*/
public class DictionaryAttributeUtil {
public static CoreDictionary.Attribute getAttribute(CoreDictionary.Attribute old, CoreDictionary.Attribute add) {
Map<Nature, Integer> map = new HashMap<>();
IntStream.range(0, old.nature.length).boxed().forEach(i -> map.put(old.nature[i], old.frequency[i]));
IntStream.range(0, add.nature.length).boxed().forEach(i -> map.put(add.nature[i], add.frequency[i]));
List<Map.Entry<Nature, Integer>> list = new LinkedList<Map.Entry<Nature, Integer>>(map.entrySet());
Collections.sort(list, new Comparator<Map.Entry<Nature, Integer>>() {
public int compare(Map.Entry<Nature, Integer> o1, Map.Entry<Nature, Integer> o2) {
return o2.getValue() - o1.getValue();
}
});
CoreDictionary.Attribute attribute = new CoreDictionary.Attribute(
list.stream().map(i -> i.getKey()).collect(Collectors.toList()).toArray(new Nature[0]),
list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(),
list.stream().map(i -> i.getValue()).findFirst().get());
if (old.original != null || add.original != null) {
attribute.original = add.original != null ? add.original : old.original;
}
return attribute;
}
}
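
A sketch of the merge semantics above: natures from both attributes are unioned (the added entry wins on collisions), sorted by frequency in descending order, and totalFrequency is set to the single highest frequency:

CoreDictionary.Attribute old = new CoreDictionary.Attribute(Nature.n, 100);
CoreDictionary.Attribute add = new CoreDictionary.Attribute(Nature.v, 300);
CoreDictionary.Attribute merged = DictionaryAttributeUtil.getAttribute(old, add);
// merged.nature == [v, n], merged.frequency == [300, 100], merged.totalFrequency == 300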

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.core.knowledge;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class DimValue2DictCommand {
private DictUpdateMode updateMode;
private List<Long> modelIds;
private Map<Long, List<Long>> modelAndDimPair = new HashMap<>();
}

View File

@@ -0,0 +1,31 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
import java.util.Date;
import java.util.Set;
import lombok.Data;
@Data
public class DimValueDictInfo {
private Long id;
private String name;
private String description;
private String command;
private TaskStatusEnum status;
private String createdBy;
private Date createdAt;
private Long elapsedMs;
private Set<Long> dimIds;
}

View File

@@ -0,0 +1,26 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import java.util.List;
import javax.validation.constraints.NotNull;
import lombok.Data;
@Data
public class DimValueInfo {
/**
 * metricId, dimensionId or domainId
 */
private Long itemId;
/**
 * element type; currently only dimension-related information is supported
 */
@NotNull
private TypeEnums type = TypeEnums.DIMENSION;
private List<String> blackList;
private List<String> whiteList;
private List<String> ruleList;
private Boolean isDictInfo;
}

View File

@@ -0,0 +1,34 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.google.common.base.Objects;
import java.util.Map;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class EmbeddingResult extends MapResult {
private String id;
private double distance;
private Map<String, String> metadata;
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
EmbeddingResult that = (EmbeddingResult) o;
return Objects.equal(id, that.id);
}
@Override
public int hashCode() {
return Objects.hashCode(id);
}
}

View File

@@ -0,0 +1,56 @@
package com.tencent.supersonic.chat.core.knowledge;
import java.util.List;
public interface FileHandler {
/**
* backup files to a specific directory
* config: dict.directory.backup
*
* @param fileName
*/
void backupFile(String fileName);
/**
* create a directory
*
* @param path
*/
void createDir(String path);
Boolean existPath(String path);
/**
* write data to a specific file under the latest-dictionary directory
* (config: dict.directory.latest)
*
* @param data
* @param fileName
* @param append
*/
void writeFile(List<String> data, String fileName, Boolean append);
/**
* get the knowledge file root directory
*
* @return
*/
String getDictRootPath();
/**
* delete a dictionary file; a backup is taken automatically
*
* @param fileName
* @return
*/
Boolean deleteDictFile(String fileName);
/**
* delete files directly without backup
*
* @param fileName
*/
void deleteFile(String fileName);
}

View File

@@ -0,0 +1,32 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.hankcs.hanlp.corpus.io.IIOAdapter;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@Slf4j
public class HadoopFileIOAdapter implements IIOAdapter {
@Override
public InputStream open(String path) throws IOException {
log.info("open:{}", path);
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(URI.create(path), conf);
return fs.open(new Path(path));
}
@Override
public OutputStream create(String path) throws IOException {
log.info("create:{}", path);
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(URI.create(path), conf);
return fs.create(new Path(path));
}
}
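A wiring sketch: HanLP resolves dictionary paths through HanLP.Config.IOAdapter, so installing this adapter lets dictionaries load from HDFS (the URL below is only illustrative).
HanLP.Config.IOAdapter = new HadoopFileIOAdapter();
// subsequent dictionary loads can then open paths such as
// "hdfs://namenode:8020/dict/custom.txt"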

View File

@@ -0,0 +1,44 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.google.common.base.Objects;
import java.util.List;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class HanlpMapResult extends MapResult {
private List<String> natures;
private int offset = 0;
private double similarity;
public HanlpMapResult(String name, List<String> natures, String detectWord) {
this.name = name;
this.natures = natures;
this.detectWord = detectWord;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
HanlpMapResult hanlpMapResult = (HanlpMapResult) o;
return Objects.equal(name, hanlpMapResult.name) && Objects.equal(natures, hanlpMapResult.natures);
}
@Override
public int hashCode() {
return Objects.hashCode(name, natures);
}
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.tencent.supersonic.chat.core.utils.NatureHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Data
@Service
public class LoadRemoveService {
@Value("${mapper.remove.agentId:}")
private Integer mapperRemoveAgentId;
@Value("${mapper.remove.nature.prefix:}")
private String mapperRemoveNaturePrefix;
public List<String> removeNatures(List<String> value, Integer agentId, Set<Long> detectModelIds) {
if (CollectionUtils.isEmpty(value)) {
return value;
}
List<String> resultList = new ArrayList<>(value);
if (!CollectionUtils.isEmpty(detectModelIds)) {
resultList.removeIf(nature -> {
if (Objects.isNull(nature)) {
return false;
}
Long modelId = NatureHelper.getModelId(nature);
if (Objects.nonNull(modelId)) {
return !detectModelIds.contains(modelId);
}
return false;
});
}
if (Objects.nonNull(mapperRemoveAgentId)
&& mapperRemoveAgentId.equals(agentId)
&& StringUtils.isNotBlank(mapperRemoveNaturePrefix)) {
resultList.removeIf(nature -> {
if (Objects.isNull(nature)) {
return false;
}
return nature.startsWith(mapperRemoveNaturePrefix);
});
}
return resultList;
}
}
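A hypothetical call, assuming the "_{modelId}_{elementId}..." nature convention produced by the word builders below and a Spring context for the bean lookup:
LoadRemoveService loadRemoveService = ContextUtils.getBean(LoadRemoveService.class);
List<String> natures = new ArrayList<>(Arrays.asList("_1_10_dimension", "_2_20_metric"));
// only model 1 is being detected, so "_2_20_metric" is filtered out
List<String> kept = loadRemoveService.removeNatures(natures, null, Collections.singleton(1L));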

View File

@@ -0,0 +1,127 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.tencent.supersonic.chat.core.config.LocalFileConfig;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.StandardOpenOption;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Slf4j
@Component
public class LocalFileHandler implements FileHandler {
private final LocalFileConfig localFileConfig;
public LocalFileHandler(LocalFileConfig localFileConfig) {
this.localFileConfig = localFileConfig;
}
@Override
public void backupFile(String fileName) {
String dictDirectoryBackup = localFileConfig.getDictDirectoryBackup();
if (!existPath(dictDirectoryBackup)) {
createDir(dictDirectoryBackup);
}
String source = localFileConfig.getDictDirectoryLatest() + "/" + fileName;
String target = dictDirectoryBackup + "/" + fileName;
Path sourcePath = Paths.get(source);
Path targetPath = Paths.get(target);
try {
Files.copy(sourcePath, targetPath, StandardCopyOption.REPLACE_EXISTING);
log.info("backupFile successfully! path:{}", targetPath.toAbsolutePath());
} catch (IOException e) {
log.error("Failed to copy file:", e);
}
}
@Override
public void createDir(String directoryPath) {
Path path = Paths.get(directoryPath);
try {
Files.createDirectories(path);
log.info("Directory created successfully!");
} catch (IOException e) {
log.error("Failed to create directory:", e);
}
}
@Override
public void deleteFile(String filePath) {
Path path = Paths.get(filePath);
try {
Files.delete(path);
log.info("File:{} deleted successfully!", getAbsolutePath(filePath));
} catch (IOException e) {
log.warn("Failed to delete file:{}, e:", getAbsolutePath(filePath), e);
}
}
@Override
public Boolean existPath(String pathStr) {
Path path = Paths.get(pathStr);
if (Files.exists(path)) {
log.info("path:{} exists!", getAbsolutePath(pathStr));
return true;
} else {
log.info("path:{} not exists!", getAbsolutePath(pathStr));
}
return false;
}
@Override
public void writeFile(List<String> lines, String fileName, Boolean append) {
String dictDirectoryLatest = localFileConfig.getDictDirectoryLatest();
if (!existPath(dictDirectoryLatest)) {
createDir(dictDirectoryLatest);
}
String filePath = dictDirectoryLatest + "/" + fileName;
if (existPath(filePath)) {
backupFile(fileName);
}
try (BufferedWriter writer = getWriter(filePath, append)) {
if (!CollectionUtils.isEmpty(lines)) {
for (String line : lines) {
writer.write(line);
writer.newLine();
}
}
log.info("File:{} written successfully!", getAbsolutePath(filePath));
} catch (IOException e) {
log.error("Failed to write file:{}", getAbsolutePath(filePath), e);
}
}
public String getAbsolutePath(String path) {
return Paths.get(path).toAbsolutePath().toString();
}
@Override
public String getDictRootPath() {
return Paths.get(localFileConfig.getDictDirectoryLatest()).toAbsolutePath().toString();
}
@Override
public Boolean deleteDictFile(String fileName) {
backupFile(fileName);
deleteFile(localFileConfig.getDictDirectoryLatest() + "/" + fileName);
return true;
}
private BufferedWriter getWriter(String filePath, Boolean append) throws IOException {
if (append) {
// CREATE lets append succeed even when the file does not exist yet
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8,
StandardOpenOption.CREATE, StandardOpenOption.APPEND);
}
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8);
}
}
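A usage sketch, assuming a configured LocalFileConfig bean: writeFile() backs up any existing file of the same name into dict.directory.backup before writing, and deleteDictFile() backs up before deleting.
FileHandler fileHandler = ContextUtils.getBean(LocalFileHandler.class);
fileHandler.writeFile(Arrays.asList("line1", "line2"), "custom_dict.txt", false);
fileHandler.deleteDictFile("custom_dict.txt");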

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.core.knowledge;
import java.io.Serializable;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
public class MapResult implements Serializable {
protected String name;
protected String detectWord;
}

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.chat.core.knowledge;
import java.io.Serializable;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
@Builder
public class ModelInfoStat implements Serializable {
private long modelCount;
private long metricModelCount;
private long dimensionModelCount;
private long dimensionValueModelCount;
}

View File

@@ -0,0 +1,396 @@
package com.tencent.supersonic.chat.core.knowledge;
import static com.hankcs.hanlp.utility.Predefine.logger;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
import com.hankcs.hanlp.corpus.io.ByteArray;
import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
import com.hankcs.hanlp.dictionary.other.CharTable;
import com.hankcs.hanlp.seg.common.Term;
import com.hankcs.hanlp.utility.LexiconUtility;
import com.hankcs.hanlp.utility.Predefine;
import com.hankcs.hanlp.utility.TextUtility;
import com.tencent.supersonic.chat.core.knowledge.dictionary.DictionaryAttributeUtil;
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
public class MultiCustomDictionary extends DynamicCustomDictionary {
/** max number of terms cached per nature */
public static int MAX_SIZE = 10;
public static Boolean removeDuplicates = true;
/** nature string -> cached terms (at most MAX_SIZE per nature), used for dimension value lookup */
public static ConcurrentHashMap<String, PriorityQueue<Term>> NATURE_TO_VALUES = new ConcurrentHashMap<>();
private static boolean addToSuggesterTrie = true;
public MultiCustomDictionary() {
this(HanLP.Config.CustomDictionaryPath);
}
public MultiCustomDictionary(String... path) {
super(path);
}
/**
 * load a dictionary file into the given map
 * @param path
 * @param defaultNature
 * @param map
 * @param customNatureCollector
 * @param addToSuggesterTrie
 * @return
 */
public static boolean load(String path, Nature defaultNature, TreeMap<String, CoreDictionary.Attribute> map,
LinkedHashSet<Nature> customNatureCollector, boolean addToSuggesterTrie) {
try {
String splitter = "\\s";
if (path.endsWith(".csv")) {
splitter = ",";
}
BufferedReader br = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(path), "UTF-8"));
boolean firstLine = true;
while (true) {
String[] param;
do {
String line;
if ((line = br.readLine()) == null) {
br.close();
return true;
}
if (firstLine) {
line = IOUtil.removeUTF8BOM(line);
firstLine = false;
}
param = line.split(splitter);
} while (param[0].length() == 0);
if (HanLP.Config.Normalization) {
param[0] = CharTable.convert(param[0]);
}
int natureCount = (param.length - 1) / 2;
CoreDictionary.Attribute attribute;
boolean isLetters = isLetters(param[0]);
String original = null;
String word = getWordBySpace(param[0]);
if (isLetters) {
original = word;
word = word.toLowerCase();
}
if (natureCount == 0) {
attribute = new CoreDictionary.Attribute(defaultNature);
} else {
attribute = new CoreDictionary.Attribute(natureCount);
for (int i = 0; i < natureCount; ++i) {
attribute.nature[i] = LexiconUtility.convertStringToNature(param[1 + 2 * i],
customNatureCollector);
attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]);
attribute.totalFrequency += attribute.frequency[i];
}
}
attribute.original = original;
if (removeDuplicates && map.containsKey(word)) {
attribute = DictionaryAttributeUtil.getAttribute(map.get(word), attribute);
}
map.put(word, attribute);
if (addToSuggesterTrie) {
SearchService.put(word, attribute);
}
for (int i = 0; i < attribute.nature.length; i++) {
Nature nature = attribute.nature[i];
PriorityQueue<Term> priorityQueue = NATURE_TO_VALUES.get(nature.toString());
if (Objects.isNull(priorityQueue)) {
priorityQueue = new PriorityQueue<>(MAX_SIZE,
Comparator.comparingInt(Term::getFrequency).reversed());
NATURE_TO_VALUES.put(nature.toString(), priorityQueue);
}
Term term = new Term(word, nature);
term.setFrequency(attribute.frequency[i]);
if (!priorityQueue.contains(term) && priorityQueue.size() < MAX_SIZE) {
priorityQueue.add(term);
}
}
}
} catch (Exception var12) {
logger.severe("Custom dictionary " + path + " read error!" + var12);
return false;
}
}
public boolean load(String... path) {
this.path = path;
long start = System.currentTimeMillis();
if (!this.loadMainDictionary(path[0])) {
Predefine.logger.warning("Custom dictionary " + Arrays.toString(path) + " failed to load");
return false;
} else {
Predefine.logger.info(
"Custom dictionary loaded successfully: " + this.dat.size() + " entries, took " + (System.currentTimeMillis() - start) + "ms");
this.path = path;
return true;
}
}
/**
* load main dictionary
* @param mainPath
* @param path
* @param dat
* @param isCache
* @param addToSuggestTrie
* @return
*/
public static boolean loadMainDictionary(String mainPath, String[] path,
DoubleArrayTrie<CoreDictionary.Attribute> dat, boolean isCache, boolean addToSuggestTrie) {
Predefine.logger.info("自定义词典开始加载:" + mainPath);
if (loadDat(mainPath, dat)) {
return true;
} else {
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap<>();
LinkedHashSet<Nature> customNatureCollector = new LinkedHashSet<>();
try {
for (String p : path) {
Nature defaultNature = Nature.n;
File file = new File(p);
String fileName = file.getName();
int cut = fileName.lastIndexOf(' ');
if (cut > 0) {
String nature = fileName.substring(cut + 1);
p = file.getParent() + File.separator + fileName.substring(0, cut);
try {
defaultNature = LexiconUtility.convertStringToNature(nature, customNatureCollector);
} catch (Exception var16) {
Predefine.logger.severe("Config file [" + p + "] is malformed!" + var16);
continue;
}
}
Predefine.logger.info("Loading custom dictionary " + p + " with default nature [" + defaultNature + "]...");
boolean success = load(p, defaultNature, map, customNatureCollector, addToSuggestTrie);
if (!success) {
Predefine.logger.warning("Failed to load: " + p);
}
}
if (map.size() == 0) {
Predefine.logger.warning("No entries were loaded");
// HanLP's built-in placeholder token, so the trie is never empty
map.put("未##它", null);
}
logger.info("Building DoubleArrayTrie...");
dat.build(map);
if (addToSuggestTrie) {
// SearchService.save();
}
if (isCache) {
// cache as a .dat file so the next load is much faster
logger.info("Caching dictionary as a .dat file...");
// cache the attribute values
List<CoreDictionary.Attribute> attributeList = new LinkedList<CoreDictionary.Attribute>();
for (Map.Entry<String, CoreDictionary.Attribute> entry : map.entrySet()) {
attributeList.add(entry.getValue());
}
DataOutputStream out = new DataOutputStream(
new BufferedOutputStream(IOUtil.newOutputStream(mainPath + ".bin")));
if (customNatureCollector.isEmpty()) {
for (int i = Nature.begin.ordinal() + 1; i < Nature.values().length; ++i) {
Nature nature = Nature.values()[i];
if (Objects.nonNull(nature)) {
customNatureCollector.add(nature);
}
}
}
IOUtil.writeCustomNature(out, customNatureCollector);
out.writeInt(attributeList.size());
for (CoreDictionary.Attribute attribute : attributeList) {
attribute.save(out);
}
dat.save(out);
out.close();
}
} catch (FileNotFoundException var17) {
logger.severe("Custom dictionary " + mainPath + " does not exist!" + var17);
return false;
} catch (IOException var18) {
logger.severe("Custom dictionary " + mainPath + " read error!" + var18);
return false;
} catch (Exception var19) {
logger.warning("Failed to cache custom dictionary " + mainPath + "!\n" + TextUtility.exceptionToString(var19));
}
return true;
}
}
public boolean loadMainDictionary(String mainPath) {
return loadMainDictionary(mainPath, this.path, this.dat, true, addToSuggesterTrie);
}
public static boolean loadDat(String path, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
return loadDat(path, HanLP.Config.CustomDictionaryPath, dat);
}
public static boolean loadDat(String path, String[] customDicPath, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
try {
if (HanLP.Config.CustomDictionaryAutoRefreshCache && isDicNeedUpdate(path, customDicPath)) {
return false;
} else {
ByteArray byteArray = ByteArray.createByteArray(path + ".bin");
if (byteArray == null) {
return false;
} else {
int size = byteArray.nextInt();
if (size < 0) {
// a negative count marks custom natures that must be registered before reading the real size
while (size < 0) {
Nature.create(byteArray.nextString());
++size;
}
size = byteArray.nextInt();
}
CoreDictionary.Attribute[] attributes = new CoreDictionary.Attribute[size];
Nature[] natureIndexArray = Nature.values();
for (int i = 0; i < size; ++i) {
int currentTotalFrequency = byteArray.nextInt();
int length = byteArray.nextInt();
attributes[i] = new CoreDictionary.Attribute(length);
attributes[i].totalFrequency = currentTotalFrequency;
for (int j = 0; j < length; ++j) {
attributes[i].nature[j] = natureIndexArray[byteArray.nextInt()];
attributes[i].frequency[j] = byteArray.nextInt();
}
}
if (!dat.load(byteArray, attributes)) {
return false;
} else {
return true;
}
}
}
} catch (Exception var11) {
logger.warning("Read failed, problem occurred at " + TextUtility.exceptionToString(var11));
return false;
}
}
/** true if the word is longer than one character and contains an uppercase ASCII letter */
public static boolean isLetters(String str) {
char[] chars = str.toCharArray();
if (chars.length <= 1) {
return false;
}
for (int i = 0; i < chars.length; i++) {
if ((chars[i] >= 'A' && chars[i] <= 'Z')) {
return true;
}
}
return false;
}
public static boolean isLowerLetter(String str) {
char[] chars = str.toCharArray();
for (int i = 0; i < chars.length; i++) {
if ((chars[i] >= 'a' && chars[i] <= 'z')) {
return true;
}
}
return false;
}
public static String getWordBySpace(String word) {
if (word.contains(HanlpHelper.SPACE_SPILT)) {
return word.replace(HanlpHelper.SPACE_SPILT, " ");
}
return word;
}
public boolean reload() {
if (this.path != null && this.path.length != 0) {
IOUtil.deleteFile(this.path[0] + ".bin");
boolean loadCacheOk = loadDat(this.path[0], this.path, this.dat);
if (!loadCacheOk) {
return this.loadMainDictionary(this.path[0], this.path, this.dat, true, addToSuggesterTrie);
}
}
return false;
}
public boolean insert(String word, String natureWithFrequency) {
if (word == null) {
return false;
} else {
if (HanLP.Config.Normalization) {
word = CharTable.convert(word);
}
CoreDictionary.Attribute att = natureWithFrequency == null ? new CoreDictionary.Attribute(Nature.nz, 1)
: CoreDictionary.Attribute.create(natureWithFrequency);
boolean isLetters = isLetters(word);
word = getWordBySpace(word);
String original = null;
if (isLetters) {
original = word;
word = word.toLowerCase();
}
if (att == null) {
return false;
} else if (this.dat.containsKey(word)) {
att.original = original;
att = DictionaryAttributeUtil.getAttribute(this.dat.get(word), att);
this.dat.set(word, att);
// return true;
} else {
if (this.trie == null) {
this.trie = new BinTrie();
}
att.original = original;
if (this.trie.containsKey(word)) {
att = DictionaryAttributeUtil.getAttribute(this.trie.get(word), att);
}
this.trie.put(word, att);
// return true;
}
if (addToSuggesterTrie) {
SearchService.put(word, att);
}
return true;
}
}
}
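For reference, a sketch of the line format load() expects: a word followed by zero or more nature/frequency pairs, separated by whitespace (or commas for .csv files). Words containing uppercase ASCII letters are lower-cased for matching, with the original casing kept in attribute.original; the nature strings below assume the word-builder conventions in this commit.
<word> [<nature> <frequency>]...
revenue _1_10_metric 100000
Revenue _1_10_metric 100000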

View File

@@ -0,0 +1,172 @@
package com.tencent.supersonic.chat.core.knowledge;
import com.hankcs.hanlp.collection.trie.bintrie.BaseNode;
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.core.knowledge.dictionary.DictionaryAttributeUtil;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service
@Slf4j
public class SearchService {
public static final int SEARCH_SIZE = 200;
private static BinTrie<List<String>> trie;
private static BinTrie<List<String>> suffixTrie;
static {
trie = new BinTrie<>();
suffixTrie = new BinTrie<>();
}
/**
 * prefix search
 * @param key
 * @return
 */
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
return prefixSearch(key, limit, agentId, trie, detectModelIds);
}
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId,
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
return result.stream().map(
entry -> {
String name = entry.getKey().replace("#", " ");
return new HanlpMapResult(name, entry.getValue(), key);
}
).sorted((a, b) -> a.getName().length() - b.getName().length())
.limit(SEARCH_SIZE)
.collect(Collectors.toList());
}
/**
 * suffix search
 * @param key
 * @return
 */
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
String reverseDetectSegment = StringUtils.reverse(key);
return suffixSearch(reverseDetectSegment, limit, agentId, suffixTrie, detectModelIds);
}
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId,
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
return result.stream().map(
entry -> {
String name = entry.getKey().replace("#", " ");
List<String> natures = entry.getValue().stream()
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
.collect(Collectors.toList());
name = StringUtils.reverse(name);
return new HanlpMapResult(name, natures, key);
}
).sorted((a, b) -> a.getName().length() - b.getName().length())
.limit(SEARCH_SIZE)
.collect(Collectors.toList());
}
private static Set<Map.Entry<String, List<String>>> prefixSearchLimit(String key, int limit,
BinTrie<List<String>> binTrie, Integer agentId, Set<Long> detectModelIds) {
key = key.toLowerCase();
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();
StringBuilder sb = new StringBuilder();
if (StringUtils.isNotBlank(key)) {
// walkLimit re-appends the matched node's char, so seed the builder without the last one
sb = new StringBuilder(key.substring(0, key.length() - 1));
}
BaseNode branch = binTrie;
char[] chars = key.toCharArray();
for (char aChar : chars) {
if (branch == null) {
return entrySet;
}
branch = branch.getChild(aChar);
}
if (branch == null) {
return entrySet;
}
branch.walkLimit(sb, entrySet, limit, agentId, detectModelIds);
return entrySet;
}
public static void clear() {
log.info("clear all trie");
trie = new BinTrie<>();
suffixTrie = new BinTrie<>();
}
public static void put(String key, CoreDictionary.Attribute attribute) {
trie.put(key, getValue(attribute.nature));
}
public static void loadSuffix(List<DictWord> suffixes) {
if (CollectionUtils.isEmpty(suffixes)) {
return;
}
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap<>();
for (DictWord suffix : suffixes) {
CoreDictionary.Attribute attributeNew = suffix.getNatureWithFrequency() == null
? new CoreDictionary.Attribute(Nature.nz, 1)
: CoreDictionary.Attribute.create(suffix.getNatureWithFrequency());
if (map.containsKey(suffix.getWord())) {
attributeNew = DictionaryAttributeUtil.getAttribute(map.get(suffix.getWord()), attributeNew);
}
map.put(suffix.getWord(), attributeNew);
}
for (Map.Entry<String, CoreDictionary.Attribute> stringAttributeEntry : map.entrySet()) {
putSuffix(stringAttributeEntry.getKey(), stringAttributeEntry.getValue());
}
}
public static void putSuffix(String key, CoreDictionary.Attribute attribute) {
Nature[] nature = attribute.nature;
suffixTrie.put(key, getValue(nature));
}
private static List<String> getValue(Nature[] nature) {
return Arrays.stream(nature).map(entry -> entry.toString()).collect(Collectors.toList());
}
public static void remove(DictWord dictWord, Nature[] natures) {
trie.remove(dictWord.getWord());
if (Objects.nonNull(natures) && natures.length > 0) {
trie.put(dictWord.getWord(), getValue(natures));
}
if (dictWord.getNature().contains(DictWordType.METRIC.getType()) || dictWord.getNature()
.contains(DictWordType.DIMENSION.getType())) {
suffixTrie.remove(dictWord.getWord());
}
}
public static List<String> getDimensionValue(DimensionValueReq dimensionValueReq) {
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT
+ dimensionValueReq.getElementID();
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
if (org.apache.commons.collections.CollectionUtils.isEmpty(terms)) {
return new ArrayList<>();
}
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
}
}
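A minimal sketch, assuming words were registered through put()/putSuffix() and a running application context (walkLimit filters natures via LoadRemoveService): prefix search descends the trie from the start of the key, while suffix search reverses the key and strips the suffix marker from the returned natures.
SearchService.put("revenue", new CoreDictionary.Attribute(Nature.nz, 1));
List<HanlpMapResult> byPrefix =
        SearchService.prefixSearch("rev", 10, null, Collections.emptySet());
// byPrefix holds a result named "revenue" together with its nature strings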

View File

@@ -0,0 +1,38 @@
package com.tencent.supersonic.chat.core.knowledge.builder;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.core.knowledge.DictWord;
import java.util.ArrayList;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
/**
* base word nature
*/
@Slf4j
public abstract class BaseWordBuilder {
public static final Long DEFAULT_FREQUENCY = 100000L;
public List<DictWord> getDictWords(List<SchemaElement> schemaElements) {
List<DictWord> dictWords = new ArrayList<>();
try {
dictWords = getDictWordsWithException(schemaElements);
} catch (Exception e) {
log.error("getWordNatureList error,", e);
}
return dictWords;
}
protected List<DictWord> getDictWordsWithException(List<SchemaElement> schemaElements) {
List<DictWord> dictWords = new ArrayList<>();
for (SchemaElement schemaElement : schemaElements) {
dictWords.addAll(doGet(schemaElement.getName(), schemaElement));
}
return dictWords;
}
protected abstract List<DictWord> doGet(String word, SchemaElement schemaElement);
}

View File

@@ -0,0 +1,63 @@
package com.tencent.supersonic.chat.core.knowledge.builder;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.core.knowledge.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
/**
* dimension word nature
*/
@Service
public class DimensionWordBuilder extends BaseWordBuilder {
@Value("${nlp.dimension.use.suffix:true}")
private boolean nlpDimensionUseSuffix = true;
@Override
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
List<DictWord> result = Lists.newArrayList();
result.add(getOneWordNature(word, schemaElement, false));
result.addAll(getOneWordNatureAlias(schemaElement, false));
if (nlpDimensionUseSuffix) {
String reverseWord = StringUtils.reverse(word);
if (StringUtils.isNotEmpty(word) && !word.equalsIgnoreCase(reverseWord)) {
result.add(getOneWordNature(reverseWord, schemaElement, true));
}
}
return result;
}
private DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
Long domainId = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + domainId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.DIMENSION.getType();
if (isSuffix) {
nature = DictWordType.NATURE_SPILT + domainId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType();
}
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
return dictWord;
}
private List<DictWord> getOneWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
List<DictWord> dictWords = new ArrayList<>();
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
return dictWords;
}
for (String alias : schemaElement.getAlias()) {
dictWords.add(getOneWordNature(alias, schemaElement, false));
}
return dictWords;
}
}
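A sketch of the output (the builder is created with new here, so the @Value default of true applies; the exact nature suffixes come from DictWordType):
SchemaElement dim = SchemaElement.builder().model(1L).id(3L).name("city").build();
List<DictWord> words = new DimensionWordBuilder().doGet("city", dim);
// words holds "city" with a nature like "_1_3<dimension type> 100000" and,
// since the suffix switch is on, the reversed word "ytic" with the suffix nature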

View File

@@ -0,0 +1,44 @@
package com.tencent.supersonic.chat.core.knowledge.builder;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.core.knowledge.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
/**
 * entity word nature
 */
@Service
@Slf4j
public class EntityWordBuilder extends BaseWordBuilder {
@Override
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
List<DictWord> result = Lists.newArrayList();
if (Objects.isNull(schemaElement)) {
return result;
}
Long domain = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + domain + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.ENTITY.getType();
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
schemaElement.getAlias().stream().forEach(alias -> {
DictWord dictWordAlias = new DictWord();
dictWordAlias.setWord(alias);
dictWordAlias.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature));
result.add(dictWordAlias);
});
}
return result;
}
}

View File

@@ -0,0 +1,63 @@
package com.tencent.supersonic.chat.core.knowledge.builder;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.core.knowledge.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
/**
* Metric DictWord
*/
@Service
public class MetricWordBuilder extends BaseWordBuilder {
@Value("${nlp.metric.use.suffix:true}")
private boolean nlpMetricUseSuffix = true;
@Override
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
List<DictWord> result = Lists.newArrayList();
result.add(getOneWordNature(word, schemaElement, false));
result.addAll(getOneWordNatureAlias(schemaElement, false));
if (nlpMetricUseSuffix) {
String reverseWord = StringUtils.reverse(word);
if (!word.equalsIgnoreCase(reverseWord)) {
result.add(getOneWordNature(reverseWord, schemaElement, true));
}
}
return result;
}
private DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
Long modelId = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.METRIC.getType();
if (isSuffix) {
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
+ DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType();
}
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
return dictWord;
}
private List<DictWord> getOneWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
List<DictWord> dictWords = new ArrayList<>();
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
return dictWords;
}
for (String alias : schemaElement.getAlias()) {
dictWords.add(getOneWordNature(alias, schemaElement, false));
}
return dictWords;
}
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.chat.core.knowledge.builder;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.core.knowledge.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.stereotype.Service;
/**
* model word nature
*/
@Service
@Slf4j
public class ModelWordBuilder extends BaseWordBuilder {
@Override
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
List<DictWord> result = Lists.newArrayList();
//modelName
DictWord dictWord = buildDictWord(word, schemaElement.getModel());
result.add(dictWord);
//alias
List<String> aliasList = schemaElement.getAlias();
if (CollectionUtils.isNotEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(buildDictWord(alias, schemaElement.getModel()));
}
}
return result;
}
private DictWord buildDictWord(String word, Long modelId) {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
String nature = DictWordType.NATURE_SPILT + modelId;
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
return dictWord;
}
}

View File

@@ -0,0 +1,40 @@
package com.tencent.supersonic.chat.core.knowledge.builder;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.core.knowledge.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
/**
* dimension value wordNature
*/
@Service
@Slf4j
public class ValueWordBuilder extends BaseWordBuilder {
@Override
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
List<DictWord> result = Lists.newArrayList();
if (Objects.nonNull(schemaElement) && !CollectionUtils.isEmpty(schemaElement.getAlias())) {
schemaElement.getAlias().stream().forEach(value -> {
DictWord dictWord = new DictWord();
Long modelId = schemaElement.getModel();
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId();
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
dictWord.setWord(value);
result.add(dictWord);
});
}
log.debug("ValueWordBuilder, result:{}", result);
return result;
}
}

View File

@@ -0,0 +1,26 @@
package com.tencent.supersonic.chat.core.knowledge.builder;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* DictWord Strategy Factory
*/
public class WordBuilderFactory {
private static Map<DictWordType, BaseWordBuilder> wordNatures = new ConcurrentHashMap<>();
static {
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
wordNatures.put(DictWordType.MODEL, new ModelWordBuilder());
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
}
public static BaseWordBuilder get(DictWordType strategyType) {
return wordNatures.get(strategyType);
}
}
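A lookup sketch, where schemaElements stands in for a caller-supplied List<SchemaElement>:
BaseWordBuilder builder = WordBuilderFactory.get(DictWordType.METRIC);
List<DictWord> dictWords = builder.getDictWords(schemaElements);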

View File

@@ -0,0 +1,67 @@
package com.tencent.supersonic.chat.core.knowledge.semantic;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
protected final Cache<String, List<ModelSchemaResp>> modelSchemaCache =
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
@SneakyThrows
public List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable) {
if (cacheEnable) {
// Cache.get(key, loader) stores the loaded value itself; no explicit put needed
return modelSchemaCache.get(String.valueOf(ids), () -> doFetchModelSchema(ids));
}
return doFetchModelSchema(ids);
}
@Override
public ModelSchema getModelSchema(Long model, Boolean cacheEnable) {
List<Long> ids = new ArrayList<>();
ids.add(model);
List<ModelSchemaResp> modelSchemaResps = fetchModelSchema(ids, cacheEnable);
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
Optional<ModelSchemaResp> modelSchemaResp = modelSchemaResps.stream()
.filter(d -> d.getId().equals(model)).findFirst();
if (modelSchemaResp.isPresent()) {
ModelSchemaResp modelSchema = modelSchemaResp.get();
return ModelSchemaBuilder.build(modelSchema);
}
}
return null;
}
@Override
public List<ModelSchema> getModelSchema() {
return getModelSchema(new ArrayList<>());
}
@Override
public List<ModelSchema> getModelSchema(List<Long> ids) {
List<ModelSchema> domainSchemaList = new ArrayList<>();
for (ModelSchemaResp resp : fetchModelSchema(ids, true)) {
domainSchemaList.add(ModelSchemaBuilder.build(resp));
}
return domainSchemaList;
}
protected abstract List<ModelSchemaResp> doFetchModelSchema(List<Long> ids);
}

View File

@@ -0,0 +1,120 @@
package com.tencent.supersonic.chat.core.knowledge.semantic;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.request.ModelSchemaFilterReq;
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.request.PageMetricReq;
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.api.response.DimensionResp;
import com.tencent.supersonic.headless.api.response.DomainResp;
import com.tencent.supersonic.headless.api.response.ExplainResp;
import com.tencent.supersonic.headless.api.response.MetricResp;
import com.tencent.supersonic.headless.api.response.ModelResp;
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
import com.tencent.supersonic.headless.api.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.SchemaService;
import java.util.HashMap;
import java.util.List;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
private SchemaService schemaService;
private DimensionService dimensionService;
private MetricService metricService;
private QueryService queryService;
@SneakyThrows
@Override
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
queryS2SQLReq.setSql(queryStructReq.getCorrectS2SQL());
queryS2SQLReq.setModelIds(queryStructReq.getModelIdSet());
queryS2SQLReq.setVariables(new HashMap<>());
return queryByS2SQL(queryS2SQLReq, user);
}
queryService = ContextUtils.getBean(QueryService.class);
return queryService.queryByStructWithAuth(queryStructReq, user);
}
@Override
public QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) {
try {
queryService = ContextUtils.getBean(QueryService.class);
return queryService.queryByMultiStruct(queryMultiStructReq, user);
} catch (Exception e) {
log.warn("queryByMultiStruct failed:", e);
}
return null;
}
@Override
@SneakyThrows
public QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user) {
queryService = ContextUtils.getBean(QueryService.class);
Object object = queryService.queryBySql(queryS2SQLReq, user);
return JsonUtil.toObject(JsonUtil.toString(object), QueryResultWithSchemaResp.class);
}
@Override
@SneakyThrows
public QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
queryService = ContextUtils.getBean(QueryService.class);
return queryService.queryDimValue(queryDimValueReq, user);
}
@Override
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
filter.setModelIds(ids);
schemaService = ContextUtils.getBean(SchemaService.class);
User user = User.getFakeUser();
return schemaService.fetchModelSchema(filter, user);
}
@Override
public List<DomainResp> getDomainList(User user) {
schemaService = ContextUtils.getBean(SchemaService.class);
return schemaService.getDomainList(user);
}
@Override
public List<ModelResp> getModelList(AuthType authType, Long domainId, User user) {
schemaService = ContextUtils.getBean(SchemaService.class);
return schemaService.getModelList(user, authType, domainId);
}
@Override
public <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception {
queryService = ContextUtils.getBean(QueryService.class);
return queryService.explain(explainSqlReq, user);
}
@Override
public PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd) {
dimensionService = ContextUtils.getBean(DimensionService.class);
return dimensionService.queryDimension(pageDimensionCmd);
}
@Override
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricReq, User user) {
metricService = ContextUtils.getBean(MetricService.class);
return metricService.queryMetric(pageMetricReq, user);
}
}

View File

@@ -0,0 +1,153 @@
package com.tencent.supersonic.chat.core.knowledge.semantic;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public class ModelSchemaBuilder {
public static ModelSchema build(ModelSchemaResp resp) {
ModelSchema modelSchema = new ModelSchema();
SchemaElement model = SchemaElement.builder()
.model(resp.getId())
.id(resp.getId())
.name(resp.getName())
.bizName(resp.getBizName())
.type(SchemaElementType.MODEL)
.alias(SchemaItem.getAliasList(resp.getAlias()))
.build();
modelSchema.setModel(model);
modelSchema.setModelRelas(resp.getModelRelas());
Set<SchemaElement> metrics = new HashSet<>();
for (MetricSchemaResp metric : resp.getMetrics()) {
List<String> alias = SchemaItem.getAliasList(metric.getAlias());
SchemaElement metricToAdd = SchemaElement.builder()
.model(resp.getId())
.id(metric.getId())
.name(metric.getName())
.bizName(metric.getBizName())
.type(SchemaElementType.METRIC)
.useCnt(metric.getUseCnt())
.alias(alias)
.relatedSchemaElements(getRelateSchemaElement(metric))
.defaultAgg(metric.getDefaultAgg())
.build();
metrics.add(metricToAdd);
}
modelSchema.getMetrics().addAll(metrics);
Set<SchemaElement> dimensions = new HashSet<>();
Set<SchemaElement> dimensionValues = new HashSet<>();
Set<SchemaElement> tags = new HashSet<>();
for (DimSchemaResp dim : resp.getDimensions()) {
List<String> alias = SchemaItem.getAliasList(dim.getAlias());
Set<String> dimValueAlias = new HashSet<>();
List<DimValueMap> dimValueMaps = dim.getDimValueMaps();
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
if (!CollectionUtils.isEmpty(dimValueMaps)) {
for (DimValueMap dimValueMap : dimValueMaps) {
if (Strings.isNotEmpty(dimValueMap.getBizName())) {
dimValueAlias.add(dimValueMap.getBizName());
}
if (!CollectionUtils.isEmpty(dimValueMap.getAlias())) {
dimValueAlias.addAll(dimValueMap.getAlias());
}
SchemaValueMap schemaValueMap = new SchemaValueMap();
BeanUtils.copyProperties(dimValueMap, schemaValueMap);
schemaValueMaps.add(schemaValueMap);
}
}
SchemaElement dimToAdd = SchemaElement.builder()
.model(resp.getId())
.id(dim.getId())
.name(dim.getName())
.bizName(dim.getBizName())
.type(SchemaElementType.DIMENSION)
.useCnt(dim.getUseCnt())
.alias(alias)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(dimToAdd);
SchemaElement dimValueToAdd = SchemaElement.builder()
.model(resp.getId())
.id(dim.getId())
.name(dim.getName())
.bizName(dim.getBizName())
.type(SchemaElementType.VALUE)
.useCnt(dim.getUseCnt())
.alias(new ArrayList<>(dimValueAlias))
.build();
dimensionValues.add(dimValueToAdd);
if (dim.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.model(resp.getId())
.id(dim.getId())
.name(dim.getName())
.bizName(dim.getBizName())
.type(SchemaElementType.TAG)
.useCnt(dim.getUseCnt())
.alias(alias)
.schemaValueMaps(schemaValueMaps)
.build();
tags.add(tagToAdd);
}
}
modelSchema.getDimensions().addAll(dimensions);
modelSchema.getDimensionValues().addAll(dimensionValues);
modelSchema.getTags().addAll(tags);
DimSchemaResp dim = resp.getPrimaryKey();
if (dim != null) {
SchemaElement entity = SchemaElement.builder()
.model(resp.getId())
.id(dim.getId())
.name(dim.getName())
.bizName(dim.getBizName())
.type(SchemaElementType.ENTITY)
.useCnt(dim.getUseCnt())
.alias(dim.getEntityAlias())
.build();
modelSchema.setEntity(entity);
}
return modelSchema;
}
private static List<RelatedSchemaElement> getRelateSchemaElement(MetricSchemaResp metricSchemaResp) {
RelateDimension relateDimension = metricSchemaResp.getRelateDimension();
if (relateDimension == null || CollectionUtils.isEmpty(relateDimension.getDrillDownDimensions())) {
return Lists.newArrayList();
}
return relateDimension.getDrillDownDimensions().stream().map(dimension -> {
RelatedSchemaElement relateSchemaElement = new RelatedSchemaElement();
BeanUtils.copyProperties(dimension, relateSchemaElement);
return relateSchemaElement;
}).collect(Collectors.toList());
}
}

View File

@@ -0,0 +1,313 @@
package com.tencent.supersonic.chat.core.knowledge.semantic;
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER;
import com.alibaba.fastjson.JSON;
import com.github.pagehelper.PageInfo;
import com.google.gson.Gson;
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
import com.tencent.supersonic.auth.api.authentication.constant.UserConstants;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.config.DefaultSemanticConfig;
import com.tencent.supersonic.common.pojo.ResultData;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.pojo.enums.ReturnCode;
import com.tencent.supersonic.common.pojo.exception.CommonException;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.S2ThreadContext;
import com.tencent.supersonic.common.util.ThreadContext;
import com.tencent.supersonic.headless.api.request.ModelSchemaFilterReq;
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.request.PageMetricReq;
import com.tencent.supersonic.headless.api.response.DimensionResp;
import com.tencent.supersonic.headless.api.response.DomainResp;
import com.tencent.supersonic.headless.api.response.ExplainResp;
import com.tencent.supersonic.headless.api.response.MetricResp;
import com.tencent.supersonic.headless.api.response.ModelResp;
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
import com.tencent.supersonic.headless.api.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import java.net.URI;
import java.net.URL;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@Slf4j
public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
private S2ThreadContext s2ThreadContext;
private AuthenticationConfig authenticationConfig;
private ParameterizedTypeReference<ResultData<QueryResultWithSchemaResp>> structTypeRef =
new ParameterizedTypeReference<ResultData<QueryResultWithSchemaResp>>() {
};
private ParameterizedTypeReference<ResultData<ExplainResp>> explainTypeRef =
new ParameterizedTypeReference<ResultData<ExplainResp>>() {
};
@Override
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
queryS2SQLReq.setSql(queryStructReq.getCorrectS2SQL());
queryS2SQLReq.setModelIds(queryStructReq.getModelIdSet());
queryS2SQLReq.setVariables(new HashMap<>());
return queryByS2SQL(queryS2SQLReq, user);
}
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
return searchByRestTemplate(
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchByStructPath(),
new Gson().toJson(queryStructReq));
}
@Override
public QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
return searchByRestTemplate(
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchByMultiStructPath(),
new Gson().toJson(queryMultiStructReq));
}
@Override
public QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user) {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
return searchByRestTemplate(defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchBySqlPath(),
new Gson().toJson(queryS2SQLReq));
}
public QueryResultWithSchemaResp searchByRestTemplate(String url, String jsonReq) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
fillToken(headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonReq, headers);
log.info("url:{},searchByRestTemplate:{}", url, entity.getBody());
ResultData<QueryResultWithSchemaResp> responseBody;
try {
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<ResultData<QueryResultWithSchemaResp>> responseEntity = restTemplate.exchange(
requestUrl, HttpMethod.POST, entity, structTypeRef);
responseBody = responseEntity.getBody();
log.info("ApiResponse<QueryResultWithColumns> responseBody:{}", responseBody);
QueryResultWithSchemaResp schemaResp = new QueryResultWithSchemaResp();
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
QueryResultWithSchemaResp data = responseBody.getData();
schemaResp.setColumns(data.getColumns());
schemaResp.setResultList(data.getResultList());
schemaResp.setSql(data.getSql());
schemaResp.setQueryAuthorization(data.getQueryAuthorization());
return schemaResp;
}
} catch (Exception e) {
throw new RuntimeException("search headless interface error,url:" + url, e);
}
throw new CommonException(responseBody.getCode(), responseBody.getMsg());
}
@Override
public QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
return searchByRestTemplate(defaultSemanticConfig.getSemanticUrl()
+ defaultSemanticConfig.getQueryDimValuePath(),
new Gson().toJson(queryDimValueReq));
}
@Override
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
HttpHeaders headers = new HttpHeaders();
headers.set(UserConstants.INTERNAL, TRUE_LOWER);
headers.setContentType(MediaType.APPLICATION_JSON);
fillToken(headers);
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
String fetchModelSchemaPath = defaultSemanticConfig.getFetchModelSchemaPath();
URI requestUrl = UriComponentsBuilder.fromHttpUrl(semanticUrl + fetchModelSchemaPath)
.build().encode().toUri();
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
filter.setModelIds(ids);
ParameterizedTypeReference<ResultData<List<ModelSchemaResp>>> responseTypeRef =
new ParameterizedTypeReference<ResultData<List<ModelSchemaResp>>>() {
};
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(filter), headers);
try {
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<ResultData<List<ModelSchemaResp>>> responseEntity = restTemplate.exchange(
requestUrl, HttpMethod.POST, entity, responseTypeRef);
ResultData<List<ModelSchemaResp>> responseBody = responseEntity.getBody();
log.debug("ApiResponse<fetchModelSchema> responseBody:{}", responseBody);
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
List<ModelSchemaResp> data = responseBody.getData();
return data;
}
} catch (Exception e) {
throw new RuntimeException("fetchModelSchema interface error", e);
}
throw new RuntimeException("fetchModelSchema interface error");
}
@Override
public List<DomainResp> getDomainList(User user) {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
Object domainDescListObject = fetchHttpResult(
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchDomainListPath(),
null, HttpMethod.GET);
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), DomainResp.class);
}
@Override
public List<ModelResp> getModelList(AuthType authType, Long domainId, User user) {
if (domainId == null) {
domainId = 0L;
}
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
String url = String.format("%s?domainId=%s&authType=%s",
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchModelListPath(),
domainId, authType.toString());
Object domainDescListObject = fetchHttpResult(url, null, HttpMethod.GET);
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), ModelResp.class);
}
@Override
public <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception {
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
String explainPath = defaultSemanticConfig.getExplainPath();
URL url = new URL(new URL(semanticUrl), explainPath);
return explain(url.toString(), JsonUtil.toString(explainSqlReq));
}
public ExplainResp explain(String url, String jsonReq) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
fillToken(headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonReq, headers);
log.info("url:{},explain:{}", url, entity.getBody());
ResultData<ExplainResp> responseBody;
try {
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<ResultData<ExplainResp>> responseEntity = restTemplate.exchange(
requestUrl, HttpMethod.POST, entity, explainTypeRef);
log.info("ApiResponse<ExplainResp> responseBody:{}", responseEntity);
responseBody = responseEntity.getBody();
if (Objects.nonNull(responseBody.getData())) {
return responseBody.getData();
}
return null;
} catch (Exception e) {
throw new RuntimeException("explain interface error,url:" + url, e);
}
}
public Object fetchHttpResult(String url, String bodyJson, HttpMethod httpMethod) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
fillToken(headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
ParameterizedTypeReference<ResultData<Object>> responseTypeRef =
new ParameterizedTypeReference<ResultData<Object>>() {
};
// bodyJson is already serialized JSON; wrapping it in JsonUtil.toString again would double-encode it
HttpEntity<String> entity = new HttpEntity<>(bodyJson, headers);
try {
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<ResultData<Object>> responseEntity = restTemplate.exchange(requestUrl,
httpMethod, entity, responseTypeRef);
ResultData<Object> responseBody = responseEntity.getBody();
log.debug("ApiResponse<fetchHttpResult> responseBody:{}", responseBody);
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
return responseBody.getData();
}
} catch (Exception e) {
throw new RuntimeException("fetchHttpResult interface error, url:" + url, e);
}
throw new RuntimeException("fetchHttpResult interface error, url:" + url);
}
public void fillToken(HttpHeaders headers) {
s2ThreadContext = ContextUtils.getBean(S2ThreadContext.class);
authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class);
ThreadContext threadContext = s2ThreadContext.get();
if (Objects.nonNull(threadContext) && Strings.isNotEmpty(threadContext.getToken())) {
if (Objects.nonNull(authenticationConfig) && Strings.isNotEmpty(
authenticationConfig.getTokenHttpHeaderKey())) {
headers.set(authenticationConfig.getTokenHttpHeaderKey(), threadContext.getToken());
}
} else {
log.debug("threadContext is null:{}", Objects.isNull(threadContext));
}
}
@Override
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user) {
String body = JsonUtil.toString(pageMetricCmd);
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
log.info("url:{}", defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath());
Object metricListObject = fetchHttpResult(
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath(),
body, HttpMethod.POST);
LinkedHashMap map = (LinkedHashMap) metricListObject;
PageInfo<Object> metricDescObjectPageInfo = generatePageInfo(map);
PageInfo<MetricResp> metricDescPageInfo = new PageInfo<>();
BeanUtils.copyProperties(metricDescObjectPageInfo, metricDescPageInfo);
// convert raw page entries into typed MetricResp objects instead of re-setting the same list
metricDescPageInfo.setList(JsonUtil.toList(JsonUtil.toString(metricDescObjectPageInfo.getList()), MetricResp.class));
return metricDescPageInfo;
}
@Override
public PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd) {
String body = JsonUtil.toString(pageDimensionCmd);
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
Object dimensionListObject = fetchHttpResult(
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchDimensionPagePath(),
body, HttpMethod.POST);
LinkedHashMap map = (LinkedHashMap) dimensionListObject;
PageInfo<Object> dimensionDescObjectPageInfo = generatePageInfo(map);
PageInfo<DimensionResp> dimensionDescPageInfo = new PageInfo<>();
BeanUtils.copyProperties(dimensionDescObjectPageInfo, dimensionDescPageInfo);
// convert raw page entries into typed DimensionResp objects instead of re-setting the same list
dimensionDescPageInfo.setList(JsonUtil.toList(JsonUtil.toString(dimensionDescObjectPageInfo.getList()), DimensionResp.class));
return dimensionDescPageInfo;
}
private PageInfo<Object> generatePageInfo(LinkedHashMap map) {
PageInfo<Object> pageInfo = new PageInfo<>();
pageInfo.setList((List<Object>) map.get(LIST_LOWER));
Integer total = (Integer) map.get(TOTAL_LOWER);
pageInfo.setTotal(total);
Integer pageSize = (Integer) map.get(PAGESIZE_LOWER);
pageInfo.setPageSize(pageSize);
// guard against a null or zero page size before computing the page count
pageInfo.setPages(pageSize == null || pageSize == 0 ? 0 : (int) Math.ceil((double) total / pageSize));
return pageInfo;
}
}

View File

@@ -0,0 +1,63 @@
package com.tencent.supersonic.chat.core.knowledge.semantic;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.request.PageMetricReq;
import com.tencent.supersonic.headless.api.response.DomainResp;
import com.tencent.supersonic.headless.api.response.DimensionResp;
import com.tencent.supersonic.headless.api.response.ExplainResp;
import com.tencent.supersonic.headless.api.response.MetricResp;
import com.tencent.supersonic.headless.api.response.ModelResp;
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
import com.tencent.supersonic.headless.api.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import java.util.List;
/**
* A semantic layer provides a simplified and consistent view of data from multiple sources.
* It abstracts away the complexity of the underlying data sources and provides a unified view
* of the data that is easier to understand and use.
* <p>
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
* Implementations of this interface should interact with the underlying data sources
* and return results in a consistent format. Alternatively, an implementation can act
* as a proxy to a remote semantic service.
* </p>
*/
public interface SemanticInterpreter {
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user);
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
List<ModelSchema> getModelSchema();
List<ModelSchema> getModelSchema(List<Long> ids);
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricReq, User user);
List<DomainResp> getDomainList(User user);
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
}
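A hypothetical usage sketch (not part of this commit): a caller holding any SemanticInterpreter implementation, such as a remote proxy, could run an S2SQL query as below. The class name and the setter on QueryS2SQLReq are assumptions for illustration.
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
import com.tencent.supersonic.headless.api.response.QueryResultWithSchemaResp;
public class SemanticInterpreterExample {

    private final SemanticInterpreter interpreter;

    public SemanticInterpreterExample(SemanticInterpreter interpreter) {
        this.interpreter = interpreter;
    }

    public QueryResultWithSchemaResp runS2Sql(User user, String s2sql) {
        QueryS2SQLReq req = new QueryS2SQLReq();
        req.setSql(s2sql); // assumed setter; the actual field name may differ
        return interpreter.queryByS2SQL(req, user);
    }
}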

View File

@@ -1,14 +1,12 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -68,10 +66,10 @@ public abstract class BaseMapper implements SchemaMapper {
}
}
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID,
SemanticSchema semanticSchema) {
SchemaElement element = new SchemaElement();
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
if (Objects.isNull(modelSchema)) {
return null;
}

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.utils.NatureHelper;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -102,7 +102,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
}
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
terms = filterByModelIds(terms, detectModelIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
List<T> matches = new ArrayList<>();

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -33,14 +32,12 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
private OptimizationConfig optimizationConfig;
@Autowired
private MapperHelper mapperHelper;
@Autowired
private SchemaService schemaService;
private List<SchemaElement> allElements;
@Override
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<Term> terms,
Set<Long> detectModelIds) {
this.allElements = getSchemaElements();
this.allElements = getSchemaElements(queryContext);
return super.match(queryContext, terms, detectModelIds);
}
@@ -62,7 +59,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
if (StringUtils.isBlank(detectSegment)) {
return;
}
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
Double metricDimensionThresholdConfig = getThreshold(queryContext);
@@ -90,10 +87,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
}
}
private List<SchemaElement> getSchemaElements() {
private List<SchemaElement> getSchemaElements(QueryContext queryContext) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
allElements.addAll(queryContext.getSemanticSchema().getDimensions());
allElements.addAll(queryContext.getSemanticSchema().getMetrics());
return allElements;
}

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.alibaba.fastjson.JSONObject;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -43,7 +43,8 @@ public class EmbeddingMapper extends BaseMapper {
continue;
}
long modelId = Long.parseLong(modelIdStr);
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId,
queryContext.getSemanticSchema());
if (schemaElement == null) {
continue;
}

View File

@@ -1,28 +1,28 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.server.listener.MetaEmbeddingListener;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* EmbeddingMatchStrategy uses a vector database to perform
@@ -35,6 +35,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private EmbeddingConfig embeddingConfig;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Override
@@ -86,7 +89,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
.build();
// step2. retrieveQuery by detectSegment
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
embeddingConfig.getMetaCollectionName(), retrieveQuery, embeddingNumber);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;

View File

@@ -1,21 +1,19 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
/**
* A mapper that converts entity dimension values (VALUE) into their corresponding ID types.
*/
@@ -30,7 +28,7 @@ public class EntityMapper extends BaseMapper {
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(modelId);
SchemaElement entity = getEntity(modelId, queryContext);
if (entity == null || entity.getId() == null) {
continue;
}
@@ -66,9 +64,9 @@ public class EntityMapper extends BaseMapper {
return false;
}
private SchemaElement getEntity(Long modelId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
private SchemaElement getEntity(Long modelId, QueryContext queryContext) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
if (modelSchema != null && modelSchema.getEntity() != null) {
return modelSchema.getEntity();
}

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.knowledge.SearchService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
@@ -69,8 +69,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
agentId,
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
oneDetectionMaxSize, agentId, detectModelIds).stream()

View File

@@ -1,16 +1,16 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
import com.tencent.supersonic.chat.core.utils.NatureHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -35,7 +35,7 @@ public class KeywordMapper extends BaseMapper {
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext.getMapInfo(), terms);
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext, terms);
//2.database Match
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
@@ -44,7 +44,7 @@ public class KeywordMapper extends BaseMapper {
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
}
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, SchemaMapInfo schemaMap,
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
List<Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
return;
@@ -65,7 +65,8 @@ public class KeywordMapper extends BaseMapper {
continue;
}
Long elementID = NatureHelper.getElementID(nature);
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
SchemaElement element = getSchemaElement(modelId, elementType, elementID,
queryContext.getSemanticSchema());
if (element == null) {
continue;
}
@@ -81,7 +82,7 @@ public class KeywordMapper extends BaseMapper {
.detectWord(hanlpMapResult.getDetectWord())
.build();
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
}
}
}

View File

@@ -1,12 +1,11 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.utils.NatureHelper;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
@@ -83,15 +82,12 @@ public class MapperHelper {
detectSegment.length());
}
public Set<Long> getModelIds(QueryReq request) {
public Set<Long> getModelIds(QueryReq request, Agent agent) {
Long modelId = request.getModelId();
AgentService agentService = ContextUtils.getBean(AgentService.class);
Set<Long> detectModelIds = agentService.getModelIds(request.getAgentId(), null);
Set<Long> detectModelIds = agent.getModelIds(null);
//contains all
if (agentService.containsAllModel(detectModelIds)) {
if (agent.containsAllModel(detectModelIds)) {
if (Objects.nonNull(modelId) && modelId > 0) {
Set<Long> result = new HashSet<>();
result.add(modelId);

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import java.util.List;
import java.util.Map;
import java.util.Set;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import java.util.Objects;
import lombok.Builder;

View File

@@ -1,16 +1,12 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -27,8 +23,7 @@ public class ModelClusterMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
@@ -46,7 +41,7 @@ public class ModelClusterMapper implements SchemaMapper {
}
private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
SemanticSchema semanticSchema) {
SemanticSchema semanticSchema) {
Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import java.io.Serializable;

View File

@@ -1,22 +1,21 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class QueryFilterMapper implements SchemaMapper {

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
/**
* A schema mapper identifies references to schema elements (metrics/dimensions/entities/values)
* in user queries. It matches the query text against the knowledge base.
*/
public interface SchemaMapper {
void map(QueryContext queryContext);
}
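A minimal illustrative implementation (not part of this commit; the query-text accessor is an assumption) shows the shape of a mapper:
import com.tencent.supersonic.chat.core.pojo.QueryContext;
public class ExactNameMapper implements SchemaMapper {

    @Override
    public void map(QueryContext queryContext) {
        // Match the raw query text against metric names in the semantic schema.
        String queryText = queryContext.getRequest().getQueryText(); // assumed accessor
        if (queryText == null) {
            return;
        }
        queryContext.getSemanticSchema().getMetrics().stream()
                .filter(metric -> queryText.contains(metric.getName()))
                .forEach(metric -> {
                    // A real mapper would build a SchemaElementMatch with a similarity
                    // score here and register it via queryContext.getMapInfo().
                });
    }
}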

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List;
import java.util.Map;
import java.util.Objects;

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.chat.parser;
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionPromptGenerator;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.Objects;

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
/**
* LLMProxy encapsulates functions performed by LLMs so that multiple
* orchestration frameworks (e.g. LangChain in Python, LangChain4j in Java)
* can be used.
*/
public interface LLMProxy {
boolean isSkip(QueryContext queryContext);
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
FunctionResp requestFunction(FunctionReq functionReq);
}
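For unit tests, a no-op stand-in could satisfy the contract, as sketched below (illustrative only; it assumes the request/response POJOs expose no-arg constructors):
public class NoOpLLMProxy implements LLMProxy {

    @Override
    public boolean isSkip(QueryContext queryContext) {
        return true; // always skip, so no real model is invoked
    }

    @Override
    public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
        return new LLMResp(); // empty response; callers must tolerate blank fields
    }

    @Override
    public FunctionResp requestFunction(FunctionReq functionReq) {
        return new FunctionResp(); // no tool is ever selected
    }
}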

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.parser;
package com.tencent.supersonic.chat.core.parser;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;

View File

@@ -1,30 +1,25 @@
package com.tencent.supersonic.chat.parser;
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
/**
* QueryTypeParser resolves the query type as METRIC, TAG, or ID.
@@ -40,14 +35,14 @@ public class QueryTypeParser implements SemanticParser {
for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL
semanticQuery.initS2Sql(user);
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
// 2.set queryType
QueryType queryType = getQueryType(semanticQuery);
QueryType queryType = getQueryType(queryContext, semanticQuery);
semanticQuery.getParseInfo().setQueryType(queryType);
}
}
private QueryType getQueryType(SemanticQuery semanticQuery) {
private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
@@ -55,13 +50,13 @@ public class QueryTypeParser implements SemanticParser {
}
//1. entity queryType
Set<Long> modelIds = parseInfo.getModel().getModelIds();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
//If all the fields in the SELECT statement are of tag type.
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
if (CollectionUtils.isNotEmpty(whereFields)) {
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
@@ -77,7 +72,6 @@ public class QueryTypeParser implements SemanticParser {
}
//2. metric queryType
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.parser;
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
/**
* A semantic parser understands user queries and extracts semantic information.
* It can leverage either a rule-based or an LLM-based approach to identify query intent
* and extract related semantic items from the query.
*/
public interface SemanticParser {
void parse(QueryContext queryContext, ChatContext chatContext);
}
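As a sketch (hypothetical, not part of this commit), a trivial rule-based parser illustrates the contract; a real parser would add candidate SemanticQuery objects to the context:
public class NoOpSemanticParser implements SemanticParser {

    @Override
    public void parse(QueryContext queryContext, ChatContext chatContext) {
        String queryText = queryContext.getRequest().getQueryText(); // assumed accessor
        if (queryText == null || queryText.isEmpty()) {
            return; // nothing to parse
        }
        // A rule-based implementation would inspect queryContext.getMapInfo() for
        // matched schema elements and build a SemanticParseInfo from them.
    }
}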

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.plugin;
package com.tencent.supersonic.chat.core.parser.plugin;
public enum ParseMode {

View File

@@ -1,22 +1,22 @@
package com.tencent.supersonic.chat.parser.plugin;
package com.tencent.supersonic.chat.core.parser.plugin;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
@@ -75,7 +75,7 @@ public abstract class PluginParser implements SemanticParser {
}
protected List<Plugin> getPluginList(QueryContext queryContext) {
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
return PluginManager.getPluginAgentCanSupport(queryContext);
}
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.parser.plugin.embedding;
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.plugin.embedding;
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.plugin.embedding;
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.plugin.function;
package com.tencent.supersonic.chat.core.parser.plugin.function;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;

View File

@@ -1,29 +1,26 @@
package com.tencent.supersonic.chat.parser.plugin.function;
package com.tencent.supersonic.chat.core.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* FunctionCallParser is a recall-plugin implementation based on LLM function calling.
*/
@@ -45,19 +42,17 @@ public class FunctionCallParser extends PluginParser {
@Override
public PluginRecallResult recallPlugin(QueryContext queryContext) {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
FunctionResp functionResp = functionCall(queryContext);
if (skipFunction(functionResp)) {
return null;
}
log.info("requestFunction result:{}", functionResp.getToolSelection());
String toolSelection = functionResp.getToolSelection();
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
if (!pluginOptional.isPresent()) {
Plugin plugin = queryContext.getNameToPlugin().get(toolSelection);
if (Objects.isNull(plugin)) {
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
return null;
}
Plugin plugin = pluginOptional.get();
plugin.setParseMode(ParseMode.FUNCTION_CALL);
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
if (pluginResolveResult.getLeft()) {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.plugin.function;
package com.tencent.supersonic.chat.core.parser.plugin.function;
import lombok.Data;

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.parser.plugin.function;
package com.tencent.supersonic.chat.core.parser.plugin.function;
import com.tencent.supersonic.chat.parser.sql.llm.InputFormat;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.core.parser.sql.llm.InputFormat;
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.parser.plugin.function;
package com.tencent.supersonic.chat.core.parser.plugin.function;
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
import java.util.List;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import lombok.Builder;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.plugin.function;
package com.tencent.supersonic.chat.core.parser.plugin.function;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.plugin.function;
package com.tencent.supersonic.chat.core.parser.plugin.function;
import lombok.Data;
import java.util.List;

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.parser.sql.llm;
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.sql.llm;
package com.tencent.supersonic.chat.core.parser.sql.llm;
import java.util.ArrayList;
import java.util.HashMap;

View File

@@ -1,28 +1,26 @@
package com.tencent.supersonic.chat.parser.sql.llm;
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.NL2SQLTool;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.agent.AgentToolType;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
import java.util.ArrayList;
@@ -49,10 +47,6 @@ public class LLMRequestService {
@Autowired
private LLMParserConfig llmParserConfig;
@Autowired
private AgentService agentService;
@Autowired
private SchemaService schemaService;
@Autowired
private OptimizationConfig optimizationConfig;
public boolean isSkip(QueryContext queryCtx) {
@@ -66,9 +60,9 @@ public class LLMRequestService {
return false;
}
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.NL2SQL_LLM);
if (agentService.containsAllModel(distinctModelIds)) {
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx) {
Set<Long> distinctModelIds = queryCtx.getAgent().getModelIds(AgentToolType.NL2SQL_LLM);
if (queryCtx.getAgent().containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>();
}
ModelResolver modelResolver = ComponentFactory.getModelResolver();
@@ -77,13 +71,13 @@ public class LLMRequestService {
return ModelCluster.build(modelCluster);
}
public NL2SQLTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
List<NL2SQLTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
AgentToolType.NL2SQL_LLM);
public NL2SQLTool getParserTool(QueryContext queryCtx, Set<Long> modelIdSet) {
Agent agent = queryCtx.getAgent();
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> {
List<Long> modelIds = tool.getModelIds();
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
if (agent.containsAllModel(new HashSet<>(modelIds))) {
return true;
}
for (Long modelId : modelIdSet) {
@@ -127,7 +121,7 @@ public class LLMRequestService {
}
llmReq.setLinking(linking);
String currentDate = S2SqlDateHelper.getReferenceDate(firstModelId);
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, firstModelId);
if (StringUtils.isEmpty(currentDate)) {
currentDate = DateUtils.getBeforeDate(0);
}
@@ -143,7 +137,7 @@ public class LLMRequestService {
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
Set<String> results = getTopNFieldNames(queryCtx, modelCluster, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
@@ -187,7 +181,7 @@ public class LLMRequestService {
}
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
.getMatchedElements(modelCluster.getKey());
@@ -210,14 +204,15 @@ public class LLMRequestService {
return new ArrayList<>(valueMatches);
}
protected Map<Long, String> getItemIdToName(ModelCluster modelCluster) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, ModelCluster modelCluster) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private Set<String> getTopNFieldNames(ModelCluster modelCluster, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
private Set<String> getTopNFieldNames(QueryContext queryCtx, ModelCluster modelCluster,
LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
@@ -235,7 +230,7 @@ public class LLMRequestService {
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
.getMatchedElements(modelCluster.getKey());
if (CollectionUtils.isEmpty(matchedElements)) {

Some files were not shown because too many files have changed in this diff.