mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:08:56 +00:00
(improvement)(chat) Split chat into three modules: server, api, and core. (#594)
This commit is contained in:
@@ -0,0 +1,334 @@
|
||||
package com.hankcs.hanlp.collection.trie.bintrie;
|
||||
|
||||
import com.hankcs.hanlp.corpus.io.ByteArray;
|
||||
import com.tencent.supersonic.chat.core.knowledge.LoadRemoveService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Queue;
|
||||
import java.util.Set;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
||||
public abstract class BaseNode<V> implements Comparable<BaseNode> {
|
||||
|
||||
/**
|
||||
* 状态数组,方便读取的时候用
|
||||
*/
|
||||
static final Status[] ARRAY_STATUS = Status.values();
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(BaseNode.class);
|
||||
/**
|
||||
* 子节点
|
||||
*/
|
||||
protected BaseNode[] child;
|
||||
/**
|
||||
* 节点状态
|
||||
*/
|
||||
protected Status status;
|
||||
/**
|
||||
* 节点代表的字符
|
||||
*/
|
||||
protected char c;
|
||||
/**
|
||||
* 节点代表的值
|
||||
*/
|
||||
protected V value;
|
||||
|
||||
protected String prefix = null;
|
||||
|
||||
public BaseNode<V> transition(String path, int begin) {
|
||||
BaseNode<V> cur = this;
|
||||
for (int i = begin; i < path.length(); ++i) {
|
||||
cur = cur.getChild(path.charAt(i));
|
||||
if (cur == null || cur.status == Status.UNDEFINED_0) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
public BaseNode<V> transition(char[] path, int begin) {
|
||||
BaseNode<V> cur = this;
|
||||
for (int i = begin; i < path.length; ++i) {
|
||||
cur = cur.getChild(path[i]);
|
||||
if (cur == null || cur.status == Status.UNDEFINED_0) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
/**
|
||||
* 转移状态
|
||||
*
|
||||
* @param path
|
||||
* @return
|
||||
*/
|
||||
public BaseNode<V> transition(char path) {
|
||||
BaseNode<V> cur = this;
|
||||
cur = cur.getChild(path);
|
||||
if (cur == null || cur.status == Status.UNDEFINED_0) {
|
||||
return null;
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加子节点
|
||||
*
|
||||
* @return true-新增了节点 false-修改了现有节点
|
||||
*/
|
||||
protected abstract boolean addChild(BaseNode node);
|
||||
|
||||
/**
|
||||
* 是否含有子节点
|
||||
*
|
||||
* @param c 子节点的char
|
||||
* @return 是否含有
|
||||
*/
|
||||
protected boolean hasChild(char c) {
|
||||
return getChild(c) != null;
|
||||
}
|
||||
|
||||
protected char getChar() {
|
||||
return c;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取子节点
|
||||
*
|
||||
* @param c 子节点的char
|
||||
* @return 子节点
|
||||
*/
|
||||
public abstract BaseNode getChild(char c);
|
||||
|
||||
/**
|
||||
* 获取节点对应的值
|
||||
*
|
||||
* @return 值
|
||||
*/
|
||||
public final V getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置节点对应的值
|
||||
*
|
||||
* @param value 值
|
||||
*/
|
||||
public final void setValue(V value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(BaseNode other) {
|
||||
return compareTo(other.getChar());
|
||||
}
|
||||
|
||||
/**
|
||||
* 重载,与字符的比较
|
||||
*
|
||||
* @param other
|
||||
* @return
|
||||
*/
|
||||
public int compareTo(char other) {
|
||||
if (this.c > other) {
|
||||
return 1;
|
||||
}
|
||||
if (this.c < other) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取节点的成词状态
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Status getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
protected void walk(StringBuilder sb, Set<Map.Entry<String, V>> entrySet) {
|
||||
sb.append(c);
|
||||
if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) {
|
||||
entrySet.add(new TrieEntry(sb.toString(), value));
|
||||
}
|
||||
if (child == null) {
|
||||
return;
|
||||
}
|
||||
for (BaseNode node : child) {
|
||||
if (node == null) {
|
||||
continue;
|
||||
}
|
||||
node.walk(new StringBuilder(sb.toString()), entrySet);
|
||||
}
|
||||
}
|
||||
|
||||
protected void walkToSave(DataOutputStream out) throws IOException {
|
||||
out.writeChar(c);
|
||||
out.writeInt(status.ordinal());
|
||||
int childSize = 0;
|
||||
if (child != null) {
|
||||
childSize = child.length;
|
||||
}
|
||||
out.writeInt(childSize);
|
||||
if (child == null) {
|
||||
return;
|
||||
}
|
||||
for (BaseNode node : child) {
|
||||
node.walkToSave(out);
|
||||
}
|
||||
}
|
||||
|
||||
protected void walkToSave(ObjectOutput out) throws IOException {
|
||||
out.writeChar(c);
|
||||
out.writeInt(status.ordinal());
|
||||
if (status == Status.WORD_END_3 || status == Status.WORD_MIDDLE_2) {
|
||||
out.writeObject(value);
|
||||
}
|
||||
int childSize = 0;
|
||||
if (child != null) {
|
||||
childSize = child.length;
|
||||
}
|
||||
out.writeInt(childSize);
|
||||
if (child == null) {
|
||||
return;
|
||||
}
|
||||
for (BaseNode node : child) {
|
||||
node.walkToSave(out);
|
||||
}
|
||||
}
|
||||
|
||||
protected void walkToLoad(ByteArray byteArray, _ValueArray<V> valueArray) {
|
||||
c = byteArray.nextChar();
|
||||
status = ARRAY_STATUS[byteArray.nextInt()];
|
||||
if (status == Status.WORD_END_3 || status == Status.WORD_MIDDLE_2) {
|
||||
value = valueArray.nextValue();
|
||||
}
|
||||
int childSize = byteArray.nextInt();
|
||||
child = new BaseNode[childSize];
|
||||
for (int i = 0; i < childSize; ++i) {
|
||||
child[i] = new Node<V>();
|
||||
child[i].walkToLoad(byteArray, valueArray);
|
||||
}
|
||||
}
|
||||
|
||||
protected void walkToLoad(ObjectInput byteArray) throws IOException, ClassNotFoundException {
|
||||
c = byteArray.readChar();
|
||||
status = ARRAY_STATUS[byteArray.readInt()];
|
||||
if (status == Status.WORD_END_3 || status == Status.WORD_MIDDLE_2) {
|
||||
value = (V) byteArray.readObject();
|
||||
}
|
||||
int childSize = byteArray.readInt();
|
||||
child = new BaseNode[childSize];
|
||||
for (int i = 0; i < childSize; ++i) {
|
||||
child[i] = new Node<V>();
|
||||
child[i].walkToLoad(byteArray);
|
||||
}
|
||||
}
|
||||
|
||||
public enum Status {
|
||||
/**
|
||||
* 未指定,用于删除词条
|
||||
*/
|
||||
UNDEFINED_0,
|
||||
/**
|
||||
* 不是词语的结尾
|
||||
*/
|
||||
NOT_WORD_1,
|
||||
/**
|
||||
* 是个词语的结尾,并且还可以继续
|
||||
*/
|
||||
WORD_MIDDLE_2,
|
||||
/**
|
||||
* 是个词语的结尾,并且没有继续
|
||||
*/
|
||||
WORD_END_3,
|
||||
}
|
||||
|
||||
public class TrieEntry extends AbstractMap.SimpleEntry<String, V> implements Comparable<TrieEntry> {
|
||||
|
||||
public TrieEntry(String key, V value) {
|
||||
super(key, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(TrieEntry o) {
|
||||
return getKey().compareTo(String.valueOf(o.getKey()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "BaseNode{"
|
||||
+ "child="
|
||||
+ Arrays.toString(child)
|
||||
+ ", status="
|
||||
+ status
|
||||
+ ", c="
|
||||
+ c
|
||||
+ ", value="
|
||||
+ value
|
||||
+ ", prefix='"
|
||||
+ prefix
|
||||
+ '\''
|
||||
+ '}';
|
||||
}
|
||||
|
||||
public void walkNode(Set<Map.Entry<String, V>> entrySet, Integer agentId, Set<Long> detectModelIds) {
|
||||
if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) {
|
||||
LoadRemoveService loadRemoveService = ContextUtils.getBean(LoadRemoveService.class);
|
||||
logger.debug("agentId:{},detectModelIds:{},before:{}", agentId, detectModelIds, value.toString());
|
||||
List natures = loadRemoveService.removeNatures((List) value, agentId, detectModelIds);
|
||||
String name = this.prefix != null ? this.prefix + c : "" + c;
|
||||
logger.debug("name:{},after:{},natures:{}", name, (List) value, natures);
|
||||
entrySet.add(new TrieEntry(name, (V) natures));
|
||||
}
|
||||
}
|
||||
|
||||
/***
|
||||
* walk limit
|
||||
* @param sb
|
||||
* @param entrySet
|
||||
* @param limit
|
||||
*/
|
||||
public void walkLimit(StringBuilder sb, Set<Map.Entry<String, V>> entrySet, int limit, Integer agentId,
|
||||
Set<Long> detectModelIds) {
|
||||
Queue<BaseNode> queue = new ArrayDeque<>();
|
||||
this.prefix = sb.toString();
|
||||
queue.add(this);
|
||||
while (!queue.isEmpty()) {
|
||||
if (entrySet.size() >= limit) {
|
||||
break;
|
||||
}
|
||||
BaseNode root = queue.poll();
|
||||
if (root == null) {
|
||||
continue;
|
||||
}
|
||||
root.walkNode(entrySet, agentId, detectModelIds);
|
||||
if (root.child == null) {
|
||||
continue;
|
||||
}
|
||||
String prefix = root.prefix + root.c;
|
||||
for (BaseNode node : root.child) {
|
||||
if (Objects.nonNull(node)) {
|
||||
node.prefix = prefix;
|
||||
queue.add(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,393 @@
|
||||
package com.hankcs.hanlp.dictionary;
|
||||
|
||||
|
||||
import static com.hankcs.hanlp.utility.Predefine.logger;
|
||||
|
||||
import com.hankcs.hanlp.HanLP;
|
||||
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
|
||||
import com.hankcs.hanlp.corpus.io.ByteArray;
|
||||
import com.hankcs.hanlp.corpus.io.IOUtil;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.utility.Predefine;
|
||||
import com.hankcs.hanlp.utility.TextUtility;
|
||||
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.Serializable;
|
||||
import java.util.Collection;
|
||||
import java.util.TreeMap;
|
||||
|
||||
/**
|
||||
* 使用DoubleArrayTrie实现的核心词典
|
||||
*/
|
||||
public class CoreDictionary {
|
||||
|
||||
public static DoubleArrayTrie<Attribute> trie = new DoubleArrayTrie<Attribute>();
|
||||
|
||||
public static final String PATH = HanLP.Config.CoreDictionaryPath;
|
||||
|
||||
// 自动加载词典
|
||||
static {
|
||||
long start = System.currentTimeMillis();
|
||||
if (!load(PATH)) {
|
||||
throw new IllegalArgumentException("核心词典" + PATH + "加载失败");
|
||||
} else {
|
||||
logger.info(PATH + "加载成功," + trie.size() + "个词条,耗时" + (System.currentTimeMillis() - start) + "ms");
|
||||
}
|
||||
}
|
||||
|
||||
// 一些特殊的WORD_ID
|
||||
public static final int NR_WORD_ID = getWordID(Predefine.TAG_PEOPLE);
|
||||
public static final int NS_WORD_ID = getWordID(Predefine.TAG_PLACE);
|
||||
public static final int NT_WORD_ID = getWordID(Predefine.TAG_GROUP);
|
||||
public static final int T_WORD_ID = getWordID(Predefine.TAG_TIME);
|
||||
public static final int X_WORD_ID = getWordID(Predefine.TAG_CLUSTER);
|
||||
public static final int M_WORD_ID = getWordID(Predefine.TAG_NUMBER);
|
||||
public static final int NX_WORD_ID = getWordID(Predefine.TAG_PROPER);
|
||||
|
||||
private static boolean load(String path) {
|
||||
logger.info("核心词典开始加载:" + path);
|
||||
if (loadDat(path)) {
|
||||
return true;
|
||||
}
|
||||
TreeMap<String, Attribute> map = new TreeMap<String, Attribute>();
|
||||
BufferedReader br = null;
|
||||
try {
|
||||
br = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(path), "UTF-8"));
|
||||
String line;
|
||||
int totalFrequency = 0;
|
||||
long start = System.currentTimeMillis();
|
||||
while ((line = br.readLine()) != null) {
|
||||
String[] param = line.split("\\s");
|
||||
int natureCount = (param.length - 1) / 2;
|
||||
Attribute attribute = new Attribute(natureCount);
|
||||
for (int i = 0; i < natureCount; ++i) {
|
||||
attribute.nature[i] = Nature.create(param[1 + 2 * i]);
|
||||
attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]);
|
||||
attribute.totalFrequency += attribute.frequency[i];
|
||||
}
|
||||
map.put(param[0], attribute);
|
||||
totalFrequency += attribute.totalFrequency;
|
||||
}
|
||||
logger.info(
|
||||
"核心词典读入词条" + map.size() + " 全部频次" + totalFrequency + ",耗时" + (System.currentTimeMillis() - start)
|
||||
+ "ms");
|
||||
br.close();
|
||||
trie.build(map);
|
||||
logger.info("核心词典加载成功:" + trie.size() + "个词条,下面将写入缓存……");
|
||||
try {
|
||||
DataOutputStream out = new DataOutputStream(
|
||||
new BufferedOutputStream(IOUtil.newOutputStream(path + Predefine.BIN_EXT)));
|
||||
Collection<Attribute> attributeList = map.values();
|
||||
out.writeInt(attributeList.size());
|
||||
for (Attribute attribute : attributeList) {
|
||||
out.writeInt(attribute.totalFrequency);
|
||||
out.writeInt(attribute.nature.length);
|
||||
for (int i = 0; i < attribute.nature.length; ++i) {
|
||||
out.writeInt(attribute.nature[i].ordinal());
|
||||
out.writeInt(attribute.frequency[i]);
|
||||
}
|
||||
}
|
||||
trie.save(out);
|
||||
out.writeInt(totalFrequency);
|
||||
Predefine.setTotalFrequency(totalFrequency);
|
||||
out.close();
|
||||
} catch (Exception e) {
|
||||
logger.warning("保存失败" + e);
|
||||
return false;
|
||||
}
|
||||
} catch (FileNotFoundException e) {
|
||||
logger.warning("核心词典" + path + "不存在!" + e);
|
||||
return false;
|
||||
} catch (IOException e) {
|
||||
logger.warning("核心词典" + path + "读取错误!" + e);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 从磁盘加载双数组
|
||||
*
|
||||
* @param path
|
||||
* @return
|
||||
*/
|
||||
static boolean loadDat(String path) {
|
||||
try {
|
||||
ByteArray byteArray = ByteArray.createByteArray(path + Predefine.BIN_EXT);
|
||||
if (byteArray == null) {
|
||||
return false;
|
||||
}
|
||||
int size = byteArray.nextInt();
|
||||
Attribute[] attributes = new Attribute[size];
|
||||
final Nature[] natureIndexArray = Nature.values();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
// 第一个是全部频次,第二个是词性个数
|
||||
int currentTotalFrequency = byteArray.nextInt();
|
||||
int length = byteArray.nextInt();
|
||||
attributes[i] = new Attribute(length);
|
||||
attributes[i].totalFrequency = currentTotalFrequency;
|
||||
for (int j = 0; j < length; ++j) {
|
||||
attributes[i].nature[j] = natureIndexArray[byteArray.nextInt()];
|
||||
attributes[i].frequency[j] = byteArray.nextInt();
|
||||
}
|
||||
}
|
||||
if (!trie.load(byteArray, attributes)) {
|
||||
return false;
|
||||
}
|
||||
int totalFrequency = 0;
|
||||
if (byteArray.hasMore()) {
|
||||
totalFrequency = byteArray.nextInt();
|
||||
} else {
|
||||
for (Attribute attribute : attributes) {
|
||||
totalFrequency += attribute.totalFrequency;
|
||||
}
|
||||
}
|
||||
Predefine.setTotalFrequency(totalFrequency);
|
||||
} catch (Exception e) {
|
||||
logger.warning("读取失败,问题发生在" + e);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取条目
|
||||
*
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static Attribute get(String key) {
|
||||
return trie.get(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取条目
|
||||
*
|
||||
* @param wordID
|
||||
* @return
|
||||
*/
|
||||
public static Attribute get(int wordID) {
|
||||
return trie.get(wordID);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取词频
|
||||
*
|
||||
* @param term
|
||||
* @return
|
||||
*/
|
||||
public static int getTermFrequency(String term) {
|
||||
Attribute attribute = get(term);
|
||||
if (attribute == null) {
|
||||
return 0;
|
||||
}
|
||||
return attribute.totalFrequency;
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否包含词语
|
||||
*
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static boolean contains(String key) {
|
||||
return trie.get(key) != null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 核心词典中的词属性
|
||||
*/
|
||||
public static class Attribute implements Serializable {
|
||||
|
||||
/**
|
||||
* 词性列表
|
||||
*/
|
||||
public Nature[] nature;
|
||||
/**
|
||||
* 词性对应的词频
|
||||
*/
|
||||
public int[] frequency;
|
||||
|
||||
public int totalFrequency;
|
||||
public String original = null;
|
||||
|
||||
|
||||
public Attribute(int size) {
|
||||
nature = new Nature[size];
|
||||
frequency = new int[size];
|
||||
}
|
||||
|
||||
public Attribute(Nature[] nature, int[] frequency) {
|
||||
this.nature = nature;
|
||||
this.frequency = frequency;
|
||||
}
|
||||
|
||||
public Attribute(Nature nature, int frequency) {
|
||||
this(1);
|
||||
this.nature[0] = nature;
|
||||
this.frequency[0] = frequency;
|
||||
totalFrequency = frequency;
|
||||
}
|
||||
|
||||
public Attribute(Nature[] nature, int[] frequency, int totalFrequency) {
|
||||
this.nature = nature;
|
||||
this.frequency = frequency;
|
||||
this.totalFrequency = totalFrequency;
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用单个词性,默认词频1000构造
|
||||
*
|
||||
* @param nature
|
||||
*/
|
||||
public Attribute(Nature nature) {
|
||||
this(nature, 1000);
|
||||
}
|
||||
|
||||
public static Attribute create(String natureWithFrequency) {
|
||||
try {
|
||||
String[] param = natureWithFrequency.split(" ");
|
||||
if (param.length % 2 != 0) {
|
||||
return new Attribute(Nature.create(natureWithFrequency.trim()), 1); // 儿童锁
|
||||
}
|
||||
int natureCount = param.length / 2;
|
||||
Attribute attribute = new Attribute(natureCount);
|
||||
for (int i = 0; i < natureCount; ++i) {
|
||||
attribute.nature[i] = Nature.create(param[2 * i]);
|
||||
attribute.frequency[i] = Integer.parseInt(param[1 + 2 * i]);
|
||||
attribute.totalFrequency += attribute.frequency[i];
|
||||
}
|
||||
return attribute;
|
||||
} catch (Exception e) {
|
||||
logger.warning("使用字符串" + natureWithFrequency + "创建词条属性失败!" + TextUtility.exceptionToString(e));
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从字节流中加载
|
||||
*
|
||||
* @param byteArray
|
||||
* @param natureIndexArray
|
||||
* @return
|
||||
*/
|
||||
public static Attribute create(ByteArray byteArray, Nature[] natureIndexArray) {
|
||||
int currentTotalFrequency = byteArray.nextInt();
|
||||
int length = byteArray.nextInt();
|
||||
Attribute attribute = new Attribute(length);
|
||||
attribute.totalFrequency = currentTotalFrequency;
|
||||
for (int j = 0; j < length; ++j) {
|
||||
attribute.nature[j] = natureIndexArray[byteArray.nextInt()];
|
||||
attribute.frequency[j] = byteArray.nextInt();
|
||||
}
|
||||
|
||||
return attribute;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取词性的词频
|
||||
*
|
||||
* @param nature 字符串词性
|
||||
* @return 词频
|
||||
* @deprecated 推荐使用Nature参数!
|
||||
*/
|
||||
public int getNatureFrequency(String nature) {
|
||||
try {
|
||||
Nature pos = Nature.create(nature);
|
||||
return getNatureFrequency(pos);
|
||||
} catch (IllegalArgumentException e) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取词性的词频
|
||||
*
|
||||
* @param nature 词性
|
||||
* @return 词频
|
||||
*/
|
||||
public int getNatureFrequency(final Nature nature) {
|
||||
int i = 0;
|
||||
for (Nature pos : this.nature) {
|
||||
if (nature == pos) {
|
||||
return frequency[i];
|
||||
}
|
||||
++i;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否有某个词性
|
||||
*
|
||||
* @param nature
|
||||
* @return
|
||||
*/
|
||||
public boolean hasNature(Nature nature) {
|
||||
return getNatureFrequency(nature) > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否有以某个前缀开头的词性
|
||||
*
|
||||
* @param prefix 词性前缀,比如u会查询是否有ude, uzhe等等
|
||||
* @return
|
||||
*/
|
||||
public boolean hasNatureStartsWith(String prefix) {
|
||||
for (Nature n : nature) {
|
||||
if (n.startsWith(prefix)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
final StringBuilder sb = new StringBuilder();
|
||||
for (int i = 0; i < nature.length; ++i) {
|
||||
sb.append(nature[i]).append(' ').append(frequency[i]).append(' ');
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public void save(DataOutputStream out) throws IOException {
|
||||
out.writeInt(totalFrequency);
|
||||
out.writeInt(nature.length);
|
||||
for (int i = 0; i < nature.length; ++i) {
|
||||
out.writeInt(nature[i].ordinal());
|
||||
out.writeInt(frequency[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取词语的ID
|
||||
*
|
||||
* @param a 词语
|
||||
* @return ID, 如果不存在, 则返回-1
|
||||
*/
|
||||
public static int getWordID(String a) {
|
||||
return CoreDictionary.trie.exactMatchSearch(a);
|
||||
}
|
||||
|
||||
/**
|
||||
* 热更新核心词典<br>
|
||||
* 集群环境(或其他IOAdapter)需要自行删除缓存文件
|
||||
*
|
||||
* @return 是否成功
|
||||
*/
|
||||
public static boolean reload() {
|
||||
String path = CoreDictionary.PATH;
|
||||
IOUtil.deleteFile(path + Predefine.BIN_EXT);
|
||||
|
||||
return load(path);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,341 @@
|
||||
package com.hankcs.hanlp.seg;
|
||||
|
||||
|
||||
import com.hankcs.hanlp.algorithm.Viterbi;
|
||||
import com.hankcs.hanlp.collection.AhoCorasick.AhoCorasickDoubleArrayTrie;
|
||||
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionaryTransformMatrixDictionary;
|
||||
import com.hankcs.hanlp.dictionary.other.CharType;
|
||||
import com.hankcs.hanlp.seg.NShort.Path.AtomNode;
|
||||
import com.hankcs.hanlp.seg.common.Graph;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.hankcs.hanlp.seg.common.Vertex;
|
||||
import com.hankcs.hanlp.seg.common.WordNet;
|
||||
import com.hankcs.hanlp.utility.TextUtility;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.ListIterator;
|
||||
|
||||
|
||||
public abstract class WordBasedSegment extends Segment {
|
||||
|
||||
public WordBasedSegment() {
|
||||
}
|
||||
|
||||
protected static void generateWord(List<Vertex> linkedArray, WordNet wordNetOptimum) {
|
||||
fixResultByRule(linkedArray);
|
||||
wordNetOptimum.addAll(linkedArray);
|
||||
}
|
||||
|
||||
protected static void fixResultByRule(List<Vertex> linkedArray) {
|
||||
mergeContinueNumIntoOne(linkedArray);
|
||||
changeDelimiterPOS(linkedArray);
|
||||
splitMiddleSlashFromDigitalWords(linkedArray);
|
||||
checkDateElements(linkedArray);
|
||||
}
|
||||
|
||||
static void changeDelimiterPOS(List<Vertex> linkedArray) {
|
||||
Iterator var1 = linkedArray.iterator();
|
||||
|
||||
while (true) {
|
||||
Vertex vertex;
|
||||
do {
|
||||
if (!var1.hasNext()) {
|
||||
return;
|
||||
}
|
||||
|
||||
vertex = (Vertex) var1.next();
|
||||
} while (!vertex.realWord.equals("--") && !vertex.realWord.equals("—") && !vertex.realWord.equals("-"));
|
||||
|
||||
vertex.confirmNature(Nature.w);
|
||||
}
|
||||
}
|
||||
|
||||
private static void splitMiddleSlashFromDigitalWords(List<Vertex> linkedArray) {
|
||||
if (linkedArray.size() >= 2) {
|
||||
ListIterator<Vertex> listIterator = linkedArray.listIterator();
|
||||
Vertex next = (Vertex) listIterator.next();
|
||||
|
||||
for (Vertex current = next; listIterator.hasNext(); current = next) {
|
||||
next = (Vertex) listIterator.next();
|
||||
Nature currentNature = current.getNature();
|
||||
if (currentNature == Nature.nx && (next.hasNature(Nature.q) || next.hasNature(Nature.n))) {
|
||||
String[] param = current.realWord.split("-", 1);
|
||||
if (param.length == 2 && TextUtility.isAllNum(param[0]) && TextUtility.isAllNum(param[1])) {
|
||||
current = current.copy();
|
||||
current.realWord = param[0];
|
||||
current.confirmNature(Nature.m);
|
||||
listIterator.previous();
|
||||
listIterator.previous();
|
||||
listIterator.set(current);
|
||||
listIterator.next();
|
||||
listIterator.add(Vertex.newPunctuationInstance("-"));
|
||||
listIterator.add(Vertex.newNumberInstance(param[1]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkDateElements(List<Vertex> linkedArray) {
|
||||
if (linkedArray.size() >= 2) {
|
||||
ListIterator<Vertex> listIterator = linkedArray.listIterator();
|
||||
Vertex next = (Vertex) listIterator.next();
|
||||
|
||||
for (Vertex current = next; listIterator.hasNext(); current = next) {
|
||||
next = (Vertex) listIterator.next();
|
||||
if (TextUtility.isAllNum(current.realWord) || TextUtility.isAllChineseNum(current.realWord)) {
|
||||
String nextWord = next.realWord;
|
||||
if (nextWord.length() == 1 && "月日时分秒".contains(nextWord)
|
||||
|| nextWord.length() == 2 && nextWord.equals("月份")) {
|
||||
mergeDate(listIterator, next, current);
|
||||
} else if (nextWord.equals("年")) {
|
||||
if (TextUtility.isYearTime(current.realWord)) {
|
||||
mergeDate(listIterator, next, current);
|
||||
} else {
|
||||
current.confirmNature(Nature.m);
|
||||
}
|
||||
} else if (current.realWord.endsWith("点")) {
|
||||
current.confirmNature(Nature.t, true);
|
||||
} else {
|
||||
char[] tmpCharArray = current.realWord.toCharArray();
|
||||
String lastChar = String.valueOf(tmpCharArray[tmpCharArray.length - 1]);
|
||||
if (!"∶·././".contains(lastChar)) {
|
||||
current.confirmNature(Nature.m, true);
|
||||
} else if (current.realWord.length() > 1) {
|
||||
char last = current.realWord.charAt(current.realWord.length() - 1);
|
||||
current = Vertex.newNumberInstance(
|
||||
current.realWord.substring(0, current.realWord.length() - 1));
|
||||
listIterator.previous();
|
||||
listIterator.previous();
|
||||
listIterator.set(current);
|
||||
listIterator.next();
|
||||
listIterator.add(Vertex.newPunctuationInstance(String.valueOf(last)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private static void mergeDate(ListIterator<Vertex> listIterator, Vertex next, Vertex current) {
|
||||
current = Vertex.newTimeInstance(current.realWord + next.realWord);
|
||||
listIterator.previous();
|
||||
listIterator.previous();
|
||||
listIterator.set(current);
|
||||
listIterator.next();
|
||||
listIterator.next();
|
||||
listIterator.remove();
|
||||
}
|
||||
|
||||
protected static List<Term> convert(List<Vertex> vertexList) {
|
||||
return convert(vertexList, false);
|
||||
}
|
||||
|
||||
protected static Graph generateBiGraph(WordNet wordNet) {
|
||||
return wordNet.toGraph();
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated
|
||||
*/
|
||||
private static List<AtomNode> atomSegment(String sSentence, int start, int end) {
|
||||
if (end < start) {
|
||||
throw new RuntimeException("start=" + start + " < end=" + end);
|
||||
} else {
|
||||
List<AtomNode> atomSegment = new ArrayList();
|
||||
int pCur = 0;
|
||||
StringBuilder sb = new StringBuilder();
|
||||
char[] charArray = sSentence.substring(start, end).toCharArray();
|
||||
int[] charTypeArray = new int[charArray.length];
|
||||
|
||||
for (int i = 0; i < charArray.length; ++i) {
|
||||
char c = charArray[i];
|
||||
charTypeArray[i] = CharType.get(c);
|
||||
if (c == '.' && i < charArray.length - 1 && CharType.get(charArray[i + 1]) == 9) {
|
||||
charTypeArray[i] = 9;
|
||||
} else if (c == '.' && i < charArray.length - 1 && charArray[i + 1] >= '0' && charArray[i + 1] <= '9') {
|
||||
charTypeArray[i] = 5;
|
||||
} else if (charTypeArray[i] == 8) {
|
||||
charTypeArray[i] = 5;
|
||||
}
|
||||
}
|
||||
|
||||
while (true) {
|
||||
while (true) {
|
||||
while (pCur < charArray.length) {
|
||||
int nCurType = charTypeArray[pCur];
|
||||
if (nCurType != 7 && nCurType != 10 && nCurType != 6 && nCurType != 17) {
|
||||
if (pCur < charArray.length - 1 && (nCurType == 5 || nCurType == 9)) {
|
||||
sb.delete(0, sb.length());
|
||||
sb.append(charArray[pCur]);
|
||||
boolean reachEnd = true;
|
||||
|
||||
while (pCur < charArray.length - 1) {
|
||||
++pCur;
|
||||
int nNextType = charTypeArray[pCur];
|
||||
if (nNextType != nCurType) {
|
||||
reachEnd = false;
|
||||
break;
|
||||
}
|
||||
|
||||
sb.append(charArray[pCur]);
|
||||
}
|
||||
|
||||
atomSegment.add(new AtomNode(sb.toString(), nCurType));
|
||||
if (reachEnd) {
|
||||
++pCur;
|
||||
}
|
||||
} else {
|
||||
atomSegment.add(new AtomNode(charArray[pCur], nCurType));
|
||||
++pCur;
|
||||
}
|
||||
} else {
|
||||
String single = String.valueOf(charArray[pCur]);
|
||||
if (single.length() != 0) {
|
||||
atomSegment.add(new AtomNode(single, nCurType));
|
||||
}
|
||||
|
||||
++pCur;
|
||||
}
|
||||
}
|
||||
|
||||
return atomSegment;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void mergeContinueNumIntoOne(List<Vertex> linkedArray) {
|
||||
if (linkedArray.size() >= 2) {
|
||||
ListIterator<Vertex> listIterator = linkedArray.listIterator();
|
||||
Vertex next = (Vertex) listIterator.next();
|
||||
Vertex current = next;
|
||||
|
||||
while (true) {
|
||||
while (listIterator.hasNext()) {
|
||||
next = (Vertex) listIterator.next();
|
||||
if (!TextUtility.isAllNum(current.realWord) && !TextUtility.isAllChineseNum(current.realWord)
|
||||
|| !TextUtility.isAllNum(next.realWord) && !TextUtility.isAllChineseNum(next.realWord)) {
|
||||
current = next;
|
||||
} else {
|
||||
current = Vertex.newNumberInstance(current.realWord + next.realWord);
|
||||
listIterator.previous();
|
||||
listIterator.previous();
|
||||
listIterator.set(current);
|
||||
listIterator.next();
|
||||
listIterator.next();
|
||||
listIterator.remove();
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected void generateWordNet(final WordNet wordNetStorage) {
|
||||
final char[] charArray = wordNetStorage.charArray;
|
||||
DoubleArrayTrie.Searcher searcher = CoreDictionary.trie.getSearcher(charArray, 0);
|
||||
|
||||
while (searcher.next()) {
|
||||
wordNetStorage.add(searcher.begin + 1, new Vertex(new String(charArray, searcher.begin, searcher.length),
|
||||
(CoreDictionary.Attribute) searcher.value, searcher.index));
|
||||
}
|
||||
|
||||
if (this.config.forceCustomDictionary) {
|
||||
this.customDictionary.parseText(charArray, new AhoCorasickDoubleArrayTrie.IHit<CoreDictionary.Attribute>() {
|
||||
public void hit(int begin, int end, CoreDictionary.Attribute value) {
|
||||
wordNetStorage.add(begin + 1, new Vertex(new String(charArray, begin, end - begin), value));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
LinkedList<Vertex>[] vertexes = wordNetStorage.getVertexes();
|
||||
int i = 1;
|
||||
|
||||
while (true) {
|
||||
while (i < vertexes.length) {
|
||||
if (vertexes[i].isEmpty()) {
|
||||
int j;
|
||||
for (j = i + 1;
|
||||
j < vertexes.length - 1 && (vertexes[j].isEmpty() || CharType.get(charArray[j - 1]) == 11);
|
||||
++j) {
|
||||
}
|
||||
|
||||
wordNetStorage.add(i, quickAtomSegment(charArray, i - 1, j - 1));
|
||||
i = j;
|
||||
} else {
|
||||
i += ((Vertex) vertexes[i].getLast()).realWord.length();
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Term> decorateResultForIndexMode(List<Vertex> vertexList, WordNet wordNetAll) {
|
||||
List<Term> termList = new LinkedList();
|
||||
int line = 1;
|
||||
ListIterator<Vertex> listIterator = vertexList.listIterator();
|
||||
listIterator.next();
|
||||
int length = vertexList.size() - 2;
|
||||
|
||||
for (int i = 0; i < length; ++i) {
|
||||
Vertex vertex = (Vertex) listIterator.next();
|
||||
Term termMain = convert(vertex);
|
||||
//termList.add(termMain);
|
||||
addTerms(termList, vertex, line - 1);
|
||||
termMain.offset = line - 1;
|
||||
if (vertex.realWord.length() > 2) {
|
||||
label43:
|
||||
for (int currentLine = line; currentLine < line + vertex.realWord.length(); ++currentLine) {
|
||||
Iterator iterator = wordNetAll.descendingIterator(currentLine);
|
||||
|
||||
while (true) {
|
||||
Vertex smallVertex;
|
||||
do {
|
||||
if (!iterator.hasNext()) {
|
||||
continue label43;
|
||||
}
|
||||
smallVertex = (Vertex) iterator.next();
|
||||
} while ((termMain.nature != Nature.mq || !smallVertex.hasNature(Nature.q))
|
||||
&& smallVertex.realWord.length() < this.config.indexMode);
|
||||
|
||||
if (smallVertex != vertex
|
||||
&& currentLine + smallVertex.realWord.length() <= line + vertex.realWord.length()) {
|
||||
listIterator.add(smallVertex);
|
||||
//Term termSub = convert(smallVertex);
|
||||
//termSub.offset = currentLine - 1;
|
||||
//termList.add(termSub);
|
||||
addTerms(termList, smallVertex, currentLine - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
line += vertex.realWord.length();
|
||||
}
|
||||
|
||||
return termList;
|
||||
}
|
||||
|
||||
protected static void speechTagging(List<Vertex> vertexList) {
|
||||
Viterbi.compute(vertexList, CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary);
|
||||
}
|
||||
|
||||
protected void addTerms(List<Term> terms, Vertex vertex, int offset) {
|
||||
for (int i = 0; i < vertex.attribute.nature.length; i++) {
|
||||
Term term = new Term(vertex.realWord, vertex.attribute.nature[i]);
|
||||
term.setFrequency(vertex.attribute.frequency[i]);
|
||||
term.offset = offset;
|
||||
terms.add(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package com.hankcs.hanlp.seg.common;
|
||||
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import com.hankcs.hanlp.dictionary.CustomDictionary;
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class Term {
|
||||
|
||||
public String word;
|
||||
|
||||
public Nature nature;
|
||||
public int offset;
|
||||
public int frequency = 0;
|
||||
|
||||
public Term(String word, Nature nature) {
|
||||
this.word = word;
|
||||
this.nature = nature;
|
||||
}
|
||||
|
||||
public Term(String word, Nature nature, int offset) {
|
||||
this.word = word;
|
||||
this.nature = nature;
|
||||
this.offset = offset;
|
||||
}
|
||||
|
||||
public Term(String word, Nature nature, int offset, int frequency) {
|
||||
this.word = word;
|
||||
this.nature = nature;
|
||||
this.offset = offset;
|
||||
this.frequency = frequency;
|
||||
}
|
||||
|
||||
public int length() {
|
||||
return this.word.length();
|
||||
}
|
||||
|
||||
public int getFrequency() {
|
||||
if (frequency > 0) {
|
||||
return frequency;
|
||||
}
|
||||
String wordOri = word.toLowerCase();
|
||||
CoreDictionary.Attribute attribute = HanlpHelper.getDynamicCustomDictionary().get(wordOri);
|
||||
if (attribute == null) {
|
||||
attribute = CoreDictionary.get(wordOri);
|
||||
if (attribute == null) {
|
||||
attribute = CustomDictionary.get(wordOri);
|
||||
}
|
||||
}
|
||||
if (attribute != null && nature != null && attribute.hasNature(nature)) {
|
||||
return attribute.getNatureFrequency(nature);
|
||||
}
|
||||
return attribute == null ? 0 : attribute.totalFrequency;
|
||||
}
|
||||
|
||||
public boolean equals(Object obj) {
|
||||
if (obj instanceof Term) {
|
||||
Term term = (Term) obj;
|
||||
if (this.nature == term.nature && this.word.equals(term.word)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return super.equals(obj);
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class ChatConfig {
|
||||
|
||||
/**
|
||||
* database auto-increment primary key
|
||||
*/
|
||||
private Long id;
|
||||
|
||||
private Long modelId;
|
||||
|
||||
/**
|
||||
* the chatDetailConfig about the model
|
||||
*/
|
||||
private ChatDetailConfigReq chatDetailConfig;
|
||||
|
||||
/**
|
||||
* the chatAggConfig about the model
|
||||
*/
|
||||
private ChatAggConfigReq chatAggConfig;
|
||||
|
||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||
|
||||
/**
|
||||
* available status
|
||||
*/
|
||||
private StatusEnum status;
|
||||
|
||||
/**
|
||||
* about createdBy, createdAt, updatedBy, updatedAt
|
||||
*/
|
||||
private RecordInfo recordInfo;
|
||||
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatConfigFilterInternal {
|
||||
|
||||
private Long id;
|
||||
private Long modelId;
|
||||
private Integer status;
|
||||
}
|
||||
@@ -1,15 +1,18 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import java.util.Objects;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Data
|
||||
public class Agent extends RecordInfo {
|
||||
@@ -23,6 +26,7 @@ public class Agent extends RecordInfo {
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
|
||||
@@ -45,4 +49,27 @@ public class Agent extends RecordInfo {
|
||||
return enableSearch != null && enableSearch == 1;
|
||||
}
|
||||
|
||||
public boolean containsAllModel(Set<Long> detectModelIds) {
|
||||
return !CollectionUtils.isEmpty(detectModelIds) && detectModelIds.contains(-1L);
|
||||
}
|
||||
|
||||
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
|
||||
List<String> tools = this.getTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(AgentToolType agentToolType) {
|
||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getModelIds)
|
||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.AllArgsConstructor;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
public enum AgentToolType {
|
||||
NL2SQL_RULE,
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.Data;
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
public class DefaultSemanticConfig {
|
||||
|
||||
@Value("${semantic.url.prefix:http://localhost:8081}")
|
||||
private String semanticUrl;
|
||||
|
||||
@Value("${searchByStruct.path:/api/semantic/query/struct}")
|
||||
private String searchByStructPath;
|
||||
|
||||
@Value("${searchByStruct.path:/api/semantic/query/multiStruct}")
|
||||
private String searchByMultiStructPath;
|
||||
|
||||
@Value("${searchByStruct.path:/api/semantic/query/sql}")
|
||||
private String searchBySqlPath;
|
||||
|
||||
@Value("${searchByStruct.path:/api/semantic/query/queryDimValue}")
|
||||
private String queryDimValuePath;
|
||||
|
||||
@Value("${fetchModelSchemaPath.path:/api/semantic/schema}")
|
||||
private String fetchModelSchemaPath;
|
||||
|
||||
@Value("${fetchModelList.path:/api/semantic/schema/dimension/page}")
|
||||
private String fetchDimensionPagePath;
|
||||
|
||||
@Value("${fetchModelList.path:/api/semantic/schema/metric/page}")
|
||||
private String fetchMetricPagePath;
|
||||
|
||||
@Value("${fetchModelList.path:/api/semantic/schema/domain/list}")
|
||||
private String fetchDomainListPath;
|
||||
|
||||
@Value("${fetchModelList.path:/api/semantic/schema/model/list}")
|
||||
private String fetchModelListPath;
|
||||
|
||||
@Value("${explain.path:/api/semantic/query/explain}")
|
||||
private String explainPath;
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.response.DimSchemaResp;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -0,0 +1,40 @@
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import java.io.FileNotFoundException;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
|
||||
@Data
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class LocalFileConfig {
|
||||
|
||||
|
||||
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
||||
private String dictDirectoryLatest;
|
||||
|
||||
@Value("${dict.directory.backup:./dict/backup}")
|
||||
private String dictDirectoryBackup;
|
||||
|
||||
public String getDictDirectoryLatest() {
|
||||
return getResourceDir() + dictDirectoryLatest;
|
||||
}
|
||||
|
||||
public String getDictDirectoryBackup() {
|
||||
return dictDirectoryBackup;
|
||||
}
|
||||
|
||||
private String getResourceDir() {
|
||||
String hanlpPropertiesPath = "";
|
||||
try {
|
||||
hanlpPropertiesPath = HanlpHelper.getHanlpPropertiesPath();
|
||||
} catch (FileNotFoundException e) {
|
||||
log.warn("getResourceDir, e:", e);
|
||||
}
|
||||
return hanlpPropertiesPath;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -1,22 +1,14 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -24,6 +16,10 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* basic semantic correction functionality, offering common methods and an
|
||||
@@ -32,23 +28,23 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
return;
|
||||
}
|
||||
doCorrect(queryReq, semanticParseInfo);
|
||||
doCorrect(queryContext, semanticParseInfo);
|
||||
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||
} catch (Exception e) {
|
||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||
}
|
||||
}
|
||||
|
||||
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
|
||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Set<Long> modelIds) {
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||
@@ -101,12 +97,12 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||
}
|
||||
|
||||
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
|
||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
List<SchemaElement> metrics = getMetricElements(modelIds);
|
||||
List<SchemaElement> metrics = getMetricElements(queryContext, modelIds);
|
||||
|
||||
Map<String, String> metricToAggregate = metrics.stream()
|
||||
.map(schemaElement -> {
|
||||
@@ -131,8 +127,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Set<Long> modelIds) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
return semanticSchema.getMetrics(modelIds);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
/**
|
||||
@@ -11,7 +11,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
public class FromCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String modelName = semanticParseInfo.getModel().getName();
|
||||
String correctSql = SqlParserReplaceHelper
|
||||
.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
|
||||
@@ -1,21 +1,18 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Group by" section in S2SQL.
|
||||
@@ -24,19 +21,19 @@ import java.util.stream.Collectors;
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
addGroupByFields(semanticParseInfo);
|
||||
addGroupByFields(queryContext, semanticParseInfo);
|
||||
|
||||
}
|
||||
|
||||
private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
|
||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
//add alias field name
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
|
||||
.flatMap(
|
||||
@@ -77,15 +74,15 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
|
||||
addAggregate(semanticParseInfo);
|
||||
addAggregate(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(SemanticParseInfo semanticParseInfo) {
|
||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
addAggregateToMetric(semanticParseInfo);
|
||||
addAggregateToMetric(queryContext, semanticParseInfo);
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,14 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -24,20 +20,20 @@ import org.springframework.util.CollectionUtils;
|
||||
public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
//add aggregate to all metric
|
||||
addHaving(semanticParseInfo);
|
||||
addHaving(queryContext, semanticParseInfo);
|
||||
|
||||
//add having expression filed to select
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
|
||||
}
|
||||
|
||||
private void addHaving(SemanticParseInfo semanticParseInfo) {
|
||||
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
@@ -24,7 +24,7 @@ import org.springframework.util.CollectionUtils;
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
correctAggFunction(semanticParseInfo);
|
||||
|
||||
@@ -34,7 +34,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
updateFieldValueByLinkingValue(semanticParseInfo);
|
||||
|
||||
correctFieldName(semanticParseInfo);
|
||||
correctFieldName(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -50,8 +50,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
|
||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getModel().getModelIds());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -14,7 +14,7 @@ import org.springframework.util.CollectionUtils;
|
||||
public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
|
||||
/**
|
||||
* A semantic corrector checks validity of extracted semantic information and
|
||||
* performs correction and optimization if needed.
|
||||
*/
|
||||
public interface SemanticCorrector {
|
||||
|
||||
void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||
}
|
||||
@@ -1,20 +1,24 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -23,13 +27,6 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||
*/
|
||||
@@ -37,19 +34,19 @@ import java.util.stream.Collectors;
|
||||
public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
addDateIfNotExist(semanticParseInfo);
|
||||
addDateIfNotExist(queryContext, semanticParseInfo);
|
||||
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
addQueryFilter(queryReq, semanticParseInfo);
|
||||
addQueryFilter(queryContext, semanticParseInfo);
|
||||
|
||||
updateFieldValueByTechName(semanticParseInfo);
|
||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
|
||||
private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
|
||||
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
|
||||
@@ -72,11 +69,11 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
|
||||
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getModelId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||
@@ -100,8 +97,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||
}
|
||||
|
||||
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class DatabaseMapResult extends MapResult {
|
||||
|
||||
private SchemaElement schemaElement;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
DatabaseMapResult that = (DatabaseMapResult) o;
|
||||
return Objects.equal(name, that.name) && Objects.equal(schemaElement, that.schemaElement);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(name, schemaElement);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class DictConfig {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private List<DimValueInfo> dimValueInfoList;
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
public enum DictUpdateMode {
|
||||
|
||||
OFFLINE_FULL("OFFLINE_FULL"),
|
||||
OFFLINE_MODEL("OFFLINE_MODEL"),
|
||||
REALTIME_ADD("REALTIME_ADD"),
|
||||
REALTIME_DELETE("REALTIME_DELETE"),
|
||||
NOT_SUPPORT("NOT_SUPPORT");
|
||||
|
||||
private String value;
|
||||
|
||||
DictUpdateMode(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public static DictUpdateMode of(String value) {
|
||||
for (DictUpdateMode item : DictUpdateMode.values()) {
|
||||
if (item.value.equalsIgnoreCase(value)) {
|
||||
return item;
|
||||
}
|
||||
}
|
||||
return DictUpdateMode.NOT_SUPPORT;
|
||||
}
|
||||
|
||||
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.util.Objects;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
/***
|
||||
* word nature
|
||||
*/
|
||||
@Data
|
||||
@ToString
|
||||
public class DictWord {
|
||||
|
||||
private String word;
|
||||
private String nature;
|
||||
private String natureWithFrequency;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
DictWord that = (DictWord) o;
|
||||
return Objects.equals(word, that.word) && Objects.equals(natureWithFrequency, that.natureWithFrequency);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(word, natureWithFrequency);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.dictionary;
|
||||
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
* Dictionary Attribute Util
|
||||
*/
|
||||
public class DictionaryAttributeUtil {
|
||||
|
||||
public static CoreDictionary.Attribute getAttribute(CoreDictionary.Attribute old, CoreDictionary.Attribute add) {
|
||||
Map<Nature, Integer> map = new HashMap<>();
|
||||
IntStream.range(0, old.nature.length).boxed().forEach(i -> map.put(old.nature[i], old.frequency[i]));
|
||||
IntStream.range(0, add.nature.length).boxed().forEach(i -> map.put(add.nature[i], add.frequency[i]));
|
||||
List<Map.Entry<Nature, Integer>> list = new LinkedList<Map.Entry<Nature, Integer>>(map.entrySet());
|
||||
Collections.sort(list, new Comparator<Map.Entry<Nature, Integer>>() {
|
||||
public int compare(Map.Entry<Nature, Integer> o1, Map.Entry<Nature, Integer> o2) {
|
||||
return o2.getValue() - o1.getValue();
|
||||
}
|
||||
});
|
||||
CoreDictionary.Attribute attribute = new CoreDictionary.Attribute(
|
||||
list.stream().map(i -> i.getKey()).collect(Collectors.toList()).toArray(new Nature[0]),
|
||||
list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(),
|
||||
list.stream().map(i -> i.getValue()).findFirst().get());
|
||||
if (old.original != null || add.original != null) {
|
||||
attribute.original = add.original != null ? add.original : old.original;
|
||||
}
|
||||
return attribute;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DimValue2DictCommand {
|
||||
|
||||
private DictUpdateMode updateMode;
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
private Map<Long, List<Long>> modelAndDimPair = new HashMap<>();
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.Set;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DimValueDictInfo {
|
||||
|
||||
private Long id;
|
||||
|
||||
private String name;
|
||||
|
||||
private String description;
|
||||
|
||||
private String command;
|
||||
|
||||
private TaskStatusEnum status;
|
||||
|
||||
private String createdBy;
|
||||
|
||||
private Date createdAt;
|
||||
|
||||
private Long elapsedMs;
|
||||
|
||||
private Set<Long> dimIds;
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import java.util.List;
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
public class DimValueInfo {
|
||||
|
||||
/**
|
||||
* metricId、DimensionId、domainId
|
||||
*/
|
||||
private Long itemId;
|
||||
|
||||
/**
|
||||
* type: IntentionTypeEnum
|
||||
* temporarily only supports dimension-related information
|
||||
*/
|
||||
@NotNull
|
||||
private TypeEnums type = TypeEnums.DIMENSION;
|
||||
|
||||
private List<String> blackList;
|
||||
private List<String> whiteList;
|
||||
private List<String> ruleList;
|
||||
private Boolean isDictInfo;
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class EmbeddingResult extends MapResult {
|
||||
|
||||
private String id;
|
||||
|
||||
private double distance;
|
||||
|
||||
private Map<String, String> metadata;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
EmbeddingResult that = (EmbeddingResult) o;
|
||||
return Objects.equal(id, that.id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(id);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface FileHandler {
|
||||
|
||||
/**
|
||||
* backup files to a specific directory
|
||||
* config: dict.directory.backup
|
||||
*
|
||||
* @param fileName
|
||||
*/
|
||||
void backupFile(String fileName);
|
||||
|
||||
/**
|
||||
* create a directory
|
||||
*
|
||||
* @param path
|
||||
*/
|
||||
void createDir(String path);
|
||||
|
||||
Boolean existPath(String path);
|
||||
|
||||
/**
|
||||
* write data to a specific file,
|
||||
* config dir: dict.directory.latest
|
||||
*
|
||||
* @param data
|
||||
* @param fileName
|
||||
* @param append
|
||||
*/
|
||||
void writeFile(List<String> data, String fileName, Boolean append);
|
||||
|
||||
/**
|
||||
* get the knowledge file root directory
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
String getDictRootPath();
|
||||
|
||||
/**
|
||||
* delete dictionary file
|
||||
* automatic backup
|
||||
*
|
||||
* @param fileName
|
||||
* @return
|
||||
*/
|
||||
Boolean deleteDictFile(String fileName);
|
||||
|
||||
/**
|
||||
* delete files directly without backup
|
||||
*
|
||||
* @param fileName
|
||||
*/
|
||||
void deleteFile(String fileName);
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import com.hankcs.hanlp.corpus.io.IIOAdapter;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.URI;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
@Slf4j
|
||||
public class HadoopFileIOAdapter implements IIOAdapter {
|
||||
|
||||
@Override
|
||||
public InputStream open(String path) throws IOException {
|
||||
log.info("open:{}", path);
|
||||
Configuration conf = new Configuration();
|
||||
FileSystem fs = FileSystem.get(URI.create(path), conf);
|
||||
return fs.open(new Path(path));
|
||||
}
|
||||
|
||||
@Override
|
||||
public OutputStream create(String path) throws IOException {
|
||||
log.info("create:{}", path);
|
||||
Configuration conf = new Configuration();
|
||||
FileSystem fs = FileSystem.get(URI.create(path), conf);
|
||||
return fs.create(new Path(path));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class HanlpMapResult extends MapResult {
|
||||
|
||||
private List<String> natures;
|
||||
private int offset = 0;
|
||||
|
||||
private double similarity;
|
||||
|
||||
public HanlpMapResult(String name, List<String> natures, String detectWord) {
|
||||
this.name = name;
|
||||
this.natures = natures;
|
||||
this.detectWord = detectWord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
HanlpMapResult hanlpMapResult = (HanlpMapResult) o;
|
||||
return Objects.equal(name, hanlpMapResult.name) && Objects.equal(natures, hanlpMapResult.natures);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(name, natures);
|
||||
}
|
||||
|
||||
public void setOffset(int offset) {
|
||||
this.offset = offset;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Data
|
||||
@Service
|
||||
public class LoadRemoveService {
|
||||
|
||||
@Value("${mapper.remove.agentId:}")
|
||||
private Integer mapperRemoveAgentId;
|
||||
|
||||
@Value("${mapper.remove.nature.prefix:}")
|
||||
private String mapperRemoveNaturePrefix;
|
||||
|
||||
public List removeNatures(List value, Integer agentId, Set<Long> detectModelIds) {
|
||||
if (CollectionUtils.isEmpty(value)) {
|
||||
return value;
|
||||
}
|
||||
List<String> resultList = new ArrayList<>(value);
|
||||
if (!CollectionUtils.isEmpty(detectModelIds)) {
|
||||
resultList.removeIf(nature -> {
|
||||
if (Objects.isNull(nature)) {
|
||||
return false;
|
||||
}
|
||||
Long modelId = NatureHelper.getModelId(nature);
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return !detectModelIds.contains(modelId);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
if (Objects.nonNull(mapperRemoveAgentId)
|
||||
&& mapperRemoveAgentId.equals(agentId)
|
||||
&& StringUtils.isNotBlank(mapperRemoveNaturePrefix)) {
|
||||
resultList.removeIf(nature -> {
|
||||
if (Objects.isNull(nature)) {
|
||||
return false;
|
||||
}
|
||||
return nature.startsWith(mapperRemoveNaturePrefix);
|
||||
});
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.config.LocalFileConfig;
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class LocalFileHandler implements FileHandler {
|
||||
|
||||
private final LocalFileConfig localFileConfig;
|
||||
|
||||
public LocalFileHandler(LocalFileConfig localFileConfig) {
|
||||
this.localFileConfig = localFileConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void backupFile(String fileName) {
|
||||
String dictDirectoryBackup = localFileConfig.getDictDirectoryBackup();
|
||||
if (!existPath(dictDirectoryBackup)) {
|
||||
createDir(dictDirectoryBackup);
|
||||
}
|
||||
|
||||
String source = localFileConfig.getDictDirectoryLatest() + "/" + fileName;
|
||||
String target = dictDirectoryBackup + "/" + fileName;
|
||||
Path sourcePath = Paths.get(source);
|
||||
Path targetPath = Paths.get(target);
|
||||
try {
|
||||
Files.copy(sourcePath, targetPath, StandardCopyOption.REPLACE_EXISTING);
|
||||
log.info("backupFile successfully! path:{}", targetPath.toAbsolutePath());
|
||||
} catch (IOException e) {
|
||||
log.info("Failed to copy file: " + e.getMessage());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createDir(String directoryPath) {
|
||||
Path path = Paths.get(directoryPath);
|
||||
try {
|
||||
Files.createDirectories(path);
|
||||
log.info("Directory created successfully!");
|
||||
} catch (IOException e) {
|
||||
log.info("Failed to create directory: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteFile(String filePath) {
|
||||
Path path = Paths.get(filePath);
|
||||
try {
|
||||
Files.delete(path);
|
||||
log.info("File:{} deleted successfully!", getAbsolutePath(filePath));
|
||||
} catch (IOException e) {
|
||||
log.warn("Failed to delete file:{}, e:", getAbsolutePath(filePath), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean existPath(String pathStr) {
|
||||
Path path = Paths.get(pathStr);
|
||||
if (Files.exists(path)) {
|
||||
log.info("path:{} exists!", getAbsolutePath(pathStr));
|
||||
return true;
|
||||
} else {
|
||||
log.info("path:{} not exists!", getAbsolutePath(pathStr));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeFile(List<String> lines, String fileName, Boolean append) {
|
||||
String dictDirectoryLatest = localFileConfig.getDictDirectoryLatest();
|
||||
if (!existPath(dictDirectoryLatest)) {
|
||||
createDir(dictDirectoryLatest);
|
||||
}
|
||||
String filePath = dictDirectoryLatest + "/" + fileName;
|
||||
if (existPath(filePath)) {
|
||||
backupFile(fileName);
|
||||
}
|
||||
try (BufferedWriter writer = getWriter(filePath, append)) {
|
||||
if (!CollectionUtils.isEmpty(lines)) {
|
||||
for (String line : lines) {
|
||||
writer.write(line);
|
||||
writer.newLine();
|
||||
}
|
||||
}
|
||||
log.info("File:{} written successfully!", getAbsolutePath(filePath));
|
||||
} catch (IOException e) {
|
||||
log.info("Failed to write file:{}, e:", getAbsolutePath(filePath), e);
|
||||
}
|
||||
}
|
||||
|
||||
public String getAbsolutePath(String path) {
|
||||
return Paths.get(path).toAbsolutePath().toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getDictRootPath() {
|
||||
return Paths.get(localFileConfig.getDictDirectoryLatest()).toAbsolutePath().toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean deleteDictFile(String fileName) {
|
||||
backupFile(fileName);
|
||||
deleteFile(localFileConfig.getDictDirectoryLatest() + "/" + fileName);
|
||||
return true;
|
||||
}
|
||||
|
||||
private BufferedWriter getWriter(String filePath, Boolean append) throws IOException {
|
||||
if (append) {
|
||||
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8, StandardOpenOption.APPEND);
|
||||
}
|
||||
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.io.Serializable;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class MapResult implements Serializable {
|
||||
|
||||
protected String name;
|
||||
protected String detectWord;
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.io.Serializable;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@Builder
|
||||
public class ModelInfoStat implements Serializable {
|
||||
|
||||
private long modelCount;
|
||||
|
||||
private long metricModelCount;
|
||||
|
||||
private long dimensionModelCount;
|
||||
|
||||
private long dimensionValueModelCount;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import static com.hankcs.hanlp.utility.Predefine.logger;
|
||||
|
||||
import com.hankcs.hanlp.HanLP;
|
||||
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
|
||||
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
|
||||
import com.hankcs.hanlp.corpus.io.ByteArray;
|
||||
import com.hankcs.hanlp.corpus.io.IOUtil;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
|
||||
import com.hankcs.hanlp.dictionary.other.CharTable;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.hankcs.hanlp.utility.LexiconUtility;
|
||||
import com.hankcs.hanlp.utility.Predefine;
|
||||
import com.hankcs.hanlp.utility.TextUtility;
|
||||
import com.tencent.supersonic.chat.core.knowledge.dictionary.DictionaryAttributeUtil;
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
|
||||
public static int MAX_SIZE = 10;
|
||||
public static Boolean removeDuplicates = true;
|
||||
public static ConcurrentHashMap<String, PriorityQueue<Term>> NATURE_TO_VALUES = new ConcurrentHashMap<>();
|
||||
private static boolean addToSuggesterTrie = true;
|
||||
|
||||
public MultiCustomDictionary() {
|
||||
this(HanLP.Config.CustomDictionaryPath);
|
||||
}
|
||||
|
||||
public MultiCustomDictionary(String... path) {
|
||||
super(path);
|
||||
}
|
||||
|
||||
/***
|
||||
* load dictionary
|
||||
* @param path
|
||||
* @param defaultNature
|
||||
* @param map
|
||||
* @param customNatureCollector
|
||||
* @param addToSuggeterTrie
|
||||
* @return
|
||||
*/
|
||||
public static boolean load(String path, Nature defaultNature, TreeMap<String, CoreDictionary.Attribute> map,
|
||||
LinkedHashSet<Nature> customNatureCollector, boolean addToSuggeterTrie) {
|
||||
try {
|
||||
String splitter = "\\s";
|
||||
if (path.endsWith(".csv")) {
|
||||
splitter = ",";
|
||||
}
|
||||
|
||||
BufferedReader br = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(path), "UTF-8"));
|
||||
boolean firstLine = true;
|
||||
|
||||
while (true) {
|
||||
String[] param;
|
||||
do {
|
||||
String line;
|
||||
if ((line = br.readLine()) == null) {
|
||||
br.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (firstLine) {
|
||||
line = IOUtil.removeUTF8BOM(line);
|
||||
firstLine = false;
|
||||
}
|
||||
|
||||
param = line.split(splitter);
|
||||
} while (param[0].length() == 0);
|
||||
|
||||
if (HanLP.Config.Normalization) {
|
||||
param[0] = CharTable.convert(param[0]);
|
||||
}
|
||||
|
||||
int natureCount = (param.length - 1) / 2;
|
||||
CoreDictionary.Attribute attribute;
|
||||
boolean isLetters = isLetters(param[0]);
|
||||
String original = null;
|
||||
String word = getWordBySpace(param[0]);
|
||||
if (isLetters) {
|
||||
original = word;
|
||||
word = word.toLowerCase();
|
||||
}
|
||||
if (natureCount == 0) {
|
||||
attribute = new CoreDictionary.Attribute(defaultNature);
|
||||
} else {
|
||||
attribute = new CoreDictionary.Attribute(natureCount);
|
||||
|
||||
for (int i = 0; i < natureCount; ++i) {
|
||||
attribute.nature[i] = LexiconUtility.convertStringToNature(param[1 + 2 * i],
|
||||
customNatureCollector);
|
||||
attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]);
|
||||
attribute.totalFrequency += attribute.frequency[i];
|
||||
}
|
||||
}
|
||||
attribute.original = original;
|
||||
|
||||
if (removeDuplicates && map.containsKey(word)) {
|
||||
attribute = DictionaryAttributeUtil.getAttribute(map.get(word), attribute);
|
||||
}
|
||||
map.put(word, attribute);
|
||||
if (addToSuggeterTrie) {
|
||||
SearchService.put(word, attribute);
|
||||
}
|
||||
for (int i = 0; i < attribute.nature.length; i++) {
|
||||
Nature nature = attribute.nature[i];
|
||||
PriorityQueue<Term> priorityQueue = NATURE_TO_VALUES.get(nature.toString());
|
||||
if (Objects.isNull(priorityQueue)) {
|
||||
priorityQueue = new PriorityQueue<>(MAX_SIZE,
|
||||
Comparator.comparingInt(Term::getFrequency).reversed());
|
||||
NATURE_TO_VALUES.put(nature.toString(), priorityQueue);
|
||||
}
|
||||
Term term = new Term(word, nature);
|
||||
term.setFrequency(attribute.frequency[i]);
|
||||
if (!priorityQueue.contains(term) && priorityQueue.size() < MAX_SIZE) {
|
||||
priorityQueue.add(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception var12) {
|
||||
logger.severe("自定义词典" + path + "读取错误!" + var12);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean load(String... path) {
|
||||
this.path = path;
|
||||
long start = System.currentTimeMillis();
|
||||
if (!this.loadMainDictionary(path[0])) {
|
||||
Predefine.logger.warning("自定义词典" + Arrays.toString(path) + "加载失败");
|
||||
return false;
|
||||
} else {
|
||||
Predefine.logger.info(
|
||||
"自定义词典加载成功:" + this.dat.size() + "个词条,耗时" + (System.currentTimeMillis() - start) + "ms");
|
||||
this.path = path;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/***
|
||||
* load main dictionary
|
||||
* @param mainPath
|
||||
* @param path
|
||||
* @param dat
|
||||
* @param isCache
|
||||
* @param addToSuggestTrie
|
||||
* @return
|
||||
*/
|
||||
public static boolean loadMainDictionary(String mainPath, String[] path,
|
||||
DoubleArrayTrie<CoreDictionary.Attribute> dat, boolean isCache, boolean addToSuggestTrie) {
|
||||
Predefine.logger.info("自定义词典开始加载:" + mainPath);
|
||||
if (loadDat(mainPath, dat)) {
|
||||
return true;
|
||||
} else {
|
||||
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap();
|
||||
LinkedHashSet customNatureCollector = new LinkedHashSet();
|
||||
|
||||
try {
|
||||
for (String p : path) {
|
||||
Nature defaultNature = Nature.n;
|
||||
File file = new File(p);
|
||||
String fileName = file.getName();
|
||||
int cut = fileName.lastIndexOf(32);
|
||||
if (cut > 0) {
|
||||
String nature = fileName.substring(cut + 1);
|
||||
p = file.getParent() + File.separator + fileName.substring(0, cut);
|
||||
|
||||
try {
|
||||
defaultNature = LexiconUtility.convertStringToNature(nature, customNatureCollector);
|
||||
} catch (Exception var16) {
|
||||
Predefine.logger.severe("配置文件【" + p + "】写错了!" + var16);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Predefine.logger.info("以默认词性[" + defaultNature + "]加载自定义词典" + p + "中……");
|
||||
boolean success = load(p, defaultNature, map, customNatureCollector, addToSuggestTrie);
|
||||
if (!success) {
|
||||
Predefine.logger.warning("失败:" + p);
|
||||
}
|
||||
}
|
||||
|
||||
if (map.size() == 0) {
|
||||
Predefine.logger.warning("没有加载到任何词条");
|
||||
map.put("未##它", null);
|
||||
}
|
||||
|
||||
logger.info("正在构建DoubleArrayTrie……");
|
||||
dat.build(map);
|
||||
if (addToSuggestTrie) {
|
||||
// SearchService.save();
|
||||
}
|
||||
if (isCache) {
|
||||
// 缓存成dat文件,下次加载会快很多
|
||||
logger.info("正在缓存词典为dat文件……");
|
||||
// 缓存值文件
|
||||
List<CoreDictionary.Attribute> attributeList = new LinkedList<CoreDictionary.Attribute>();
|
||||
for (Map.Entry<String, CoreDictionary.Attribute> entry : map.entrySet()) {
|
||||
attributeList.add(entry.getValue());
|
||||
}
|
||||
|
||||
DataOutputStream out = new DataOutputStream(
|
||||
new BufferedOutputStream(IOUtil.newOutputStream(mainPath + ".bin")));
|
||||
if (customNatureCollector.isEmpty()) {
|
||||
for (int i = Nature.begin.ordinal() + 1; i < Nature.values().length; ++i) {
|
||||
Nature nature = Nature.values()[i];
|
||||
if (Objects.nonNull(nature)) {
|
||||
customNatureCollector.add(nature);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
IOUtil.writeCustomNature(out, customNatureCollector);
|
||||
out.writeInt(attributeList.size());
|
||||
|
||||
for (CoreDictionary.Attribute attribute : attributeList) {
|
||||
attribute.save(out);
|
||||
}
|
||||
|
||||
dat.save(out);
|
||||
out.close();
|
||||
}
|
||||
} catch (FileNotFoundException var17) {
|
||||
logger.severe("自定义词典" + mainPath + "不存在!" + var17);
|
||||
return false;
|
||||
} catch (IOException var18) {
|
||||
logger.severe("自定义词典" + mainPath + "读取错误!" + var18);
|
||||
return false;
|
||||
} catch (Exception var19) {
|
||||
logger.warning("自定义词典" + mainPath + "缓存失败!\n" + TextUtility.exceptionToString(var19));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean loadMainDictionary(String mainPath) {
|
||||
return loadMainDictionary(mainPath, this.path, this.dat, true, addToSuggesterTrie);
|
||||
}
|
||||
|
||||
public static boolean loadDat(String path, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
|
||||
return loadDat(path, HanLP.Config.CustomDictionaryPath, dat);
|
||||
}
|
||||
|
||||
public static boolean loadDat(String path, String[] customDicPath, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
|
||||
try {
|
||||
if (HanLP.Config.CustomDictionaryAutoRefreshCache && isDicNeedUpdate(path, customDicPath)) {
|
||||
return false;
|
||||
} else {
|
||||
ByteArray byteArray = ByteArray.createByteArray(path + ".bin");
|
||||
if (byteArray == null) {
|
||||
return false;
|
||||
} else {
|
||||
int size = byteArray.nextInt();
|
||||
if (size < 0) {
|
||||
while (true) {
|
||||
++size;
|
||||
if (size > 0) {
|
||||
size = byteArray.nextInt();
|
||||
break;
|
||||
}
|
||||
|
||||
Nature.create(byteArray.nextString());
|
||||
}
|
||||
}
|
||||
|
||||
CoreDictionary.Attribute[] attributes = new CoreDictionary.Attribute[size];
|
||||
Nature[] natureIndexArray = Nature.values();
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
int currentTotalFrequency = byteArray.nextInt();
|
||||
int length = byteArray.nextInt();
|
||||
attributes[i] = new CoreDictionary.Attribute(length);
|
||||
attributes[i].totalFrequency = currentTotalFrequency;
|
||||
|
||||
for (int j = 0; j < length; ++j) {
|
||||
attributes[i].nature[j] = natureIndexArray[byteArray.nextInt()];
|
||||
attributes[i].frequency[j] = byteArray.nextInt();
|
||||
}
|
||||
}
|
||||
|
||||
if (!dat.load(byteArray, attributes)) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception var11) {
|
||||
logger.warning("读取失败,问题发生在" + TextUtility.exceptionToString(var11));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean isLetters(String str) {
|
||||
char[] chars = str.toCharArray();
|
||||
if (chars.length <= 1) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < chars.length; i++) {
|
||||
if ((chars[i] >= 'A' && chars[i] <= 'Z')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean isLowerLetter(String str) {
|
||||
char[] chars = str.toCharArray();
|
||||
for (int i = 0; i < chars.length; i++) {
|
||||
if ((chars[i] >= 'a' && chars[i] <= 'z')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static String getWordBySpace(String word) {
|
||||
if (word.contains(HanlpHelper.SPACE_SPILT)) {
|
||||
return word.replace(HanlpHelper.SPACE_SPILT, " ");
|
||||
}
|
||||
return word;
|
||||
}
|
||||
|
||||
public boolean reload() {
|
||||
if (this.path != null && this.path.length != 0) {
|
||||
IOUtil.deleteFile(this.path[0] + ".bin");
|
||||
Boolean loadCacheOk = this.loadDat(this.path[0], this.path, this.dat);
|
||||
if (!loadCacheOk) {
|
||||
return this.loadMainDictionary(this.path[0], this.path, this.dat, true, addToSuggesterTrie);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
}
|
||||
|
||||
public boolean insert(String word, String natureWithFrequency) {
|
||||
if (word == null) {
|
||||
return false;
|
||||
} else {
|
||||
if (HanLP.Config.Normalization) {
|
||||
word = CharTable.convert(word);
|
||||
}
|
||||
CoreDictionary.Attribute att = natureWithFrequency == null ? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(natureWithFrequency);
|
||||
boolean isLetters = isLetters(word);
|
||||
word = getWordBySpace(word);
|
||||
String original = null;
|
||||
if (isLetters) {
|
||||
original = word;
|
||||
word = word.toLowerCase();
|
||||
}
|
||||
if (att == null) {
|
||||
return false;
|
||||
} else if (this.dat.containsKey(word)) {
|
||||
att.original = original;
|
||||
att = DictionaryAttributeUtil.getAttribute(this.dat.get(word), att);
|
||||
this.dat.set(word, att);
|
||||
// return true;
|
||||
} else {
|
||||
if (this.trie == null) {
|
||||
this.trie = new BinTrie();
|
||||
}
|
||||
att.original = original;
|
||||
if (this.trie.containsKey(word)) {
|
||||
att = DictionaryAttributeUtil.getAttribute(this.trie.get(word), att);
|
||||
}
|
||||
this.trie.put(word, att);
|
||||
// return true;
|
||||
}
|
||||
if (addToSuggesterTrie) {
|
||||
SearchService.put(word, att);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import com.hankcs.hanlp.collection.trie.bintrie.BaseNode;
|
||||
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.chat.core.knowledge.dictionary.DictionaryAttributeUtil;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.TreeSet;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class SearchService {
|
||||
|
||||
public static final int SEARCH_SIZE = 200;
|
||||
private static BinTrie<List<String>> trie;
|
||||
private static BinTrie<List<String>> suffixTrie;
|
||||
|
||||
static {
|
||||
trie = new BinTrie<>();
|
||||
suffixTrie = new BinTrie<>();
|
||||
}
|
||||
|
||||
/***
|
||||
* prefix Search
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
|
||||
return prefixSearch(key, limit, agentId, trie, detectModelIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Integer agentId,
|
||||
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
|
||||
return result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
return new HanlpMapResult(name, entry.getValue(), key);
|
||||
}
|
||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.limit(SEARCH_SIZE)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/***
|
||||
* suffix Search
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId, Set<Long> detectModelIds) {
|
||||
String reverseDetectSegment = StringUtils.reverse(key);
|
||||
return suffixSearch(reverseDetectSegment, limit, agentId, suffixTrie, detectModelIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Integer agentId,
|
||||
BinTrie<List<String>> binTrie, Set<Long> detectModelIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = prefixSearchLimit(key, limit, binTrie, agentId, detectModelIds);
|
||||
return result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
List<String> natures = entry.getValue().stream()
|
||||
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
|
||||
.collect(Collectors.toList());
|
||||
name = StringUtils.reverse(name);
|
||||
return new HanlpMapResult(name, natures, key);
|
||||
}
|
||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.limit(SEARCH_SIZE)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Map.Entry<String, List<String>>> prefixSearchLimit(String key, int limit,
|
||||
BinTrie<List<String>> binTrie, Integer agentId, Set<Long> detectModelIds) {
|
||||
key = key.toLowerCase();
|
||||
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (StringUtils.isNotBlank(key)) {
|
||||
sb = new StringBuilder(key.substring(0, key.length() - 1));
|
||||
}
|
||||
BaseNode branch = binTrie;
|
||||
char[] chars = key.toCharArray();
|
||||
for (char aChar : chars) {
|
||||
if (branch == null) {
|
||||
return entrySet;
|
||||
}
|
||||
branch = branch.getChild(aChar);
|
||||
}
|
||||
|
||||
if (branch == null) {
|
||||
return entrySet;
|
||||
}
|
||||
branch.walkLimit(sb, entrySet, limit, agentId, detectModelIds);
|
||||
return entrySet;
|
||||
}
|
||||
|
||||
public static void clear() {
|
||||
log.info("clear all trie");
|
||||
trie = new BinTrie<>();
|
||||
suffixTrie = new BinTrie<>();
|
||||
}
|
||||
|
||||
public static void put(String key, CoreDictionary.Attribute attribute) {
|
||||
trie.put(key, getValue(attribute.nature));
|
||||
}
|
||||
|
||||
public static void loadSuffix(List<DictWord> suffixes) {
|
||||
if (CollectionUtils.isEmpty(suffixes)) {
|
||||
return;
|
||||
}
|
||||
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap();
|
||||
for (DictWord suffix : suffixes) {
|
||||
CoreDictionary.Attribute attributeNew = suffix.getNatureWithFrequency() == null
|
||||
? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(suffix.getNatureWithFrequency());
|
||||
if (map.containsKey(suffix.getWord())) {
|
||||
attributeNew = DictionaryAttributeUtil.getAttribute(map.get(suffix.getWord()), attributeNew);
|
||||
}
|
||||
map.put(suffix.getWord(), attributeNew);
|
||||
}
|
||||
for (Map.Entry<String, CoreDictionary.Attribute> stringAttributeEntry : map.entrySet()) {
|
||||
putSuffix(stringAttributeEntry.getKey(), stringAttributeEntry.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
public static void putSuffix(String key, CoreDictionary.Attribute attribute) {
|
||||
Nature[] nature = attribute.nature;
|
||||
suffixTrie.put(key, getValue(nature));
|
||||
}
|
||||
|
||||
private static List<String> getValue(Nature[] nature) {
|
||||
return Arrays.stream(nature).map(entry -> entry.toString()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static void remove(DictWord dictWord, Nature[] natures) {
|
||||
trie.remove(dictWord.getWord());
|
||||
if (Objects.nonNull(natures) && natures.length > 0) {
|
||||
trie.put(dictWord.getWord(), getValue(natures));
|
||||
}
|
||||
if (dictWord.getNature().contains(DictWordType.METRIC.getType()) || dictWord.getNature()
|
||||
.contains(DictWordType.DIMENSION.getType())) {
|
||||
suffixTrie.remove(dictWord.getWord());
|
||||
}
|
||||
}
|
||||
|
||||
public static List<String> getDimensionValue(DimensionValueReq dimensionValueReq) {
|
||||
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT
|
||||
+ dimensionValueReq.getElementID();
|
||||
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
|
||||
if (org.apache.commons.collections.CollectionUtils.isEmpty(terms)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.builder;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* base word nature
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseWordBuilder {
|
||||
|
||||
public static final Long DEFAULT_FREQUENCY = 100000L;
|
||||
|
||||
public List<DictWord> getDictWords(List<SchemaElement> schemaElements) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
try {
|
||||
dictWords = getDictWordsWithException(schemaElements);
|
||||
} catch (Exception e) {
|
||||
log.error("getWordNatureList error,", e);
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
protected List<DictWord> getDictWordsWithException(List<SchemaElement> schemaElements) {
|
||||
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
dictWords.addAll(doGet(schemaElement.getName(), schemaElement));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
protected abstract List<DictWord> doGet(String word, SchemaElement schemaElement);
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* dimension word nature
|
||||
*/
|
||||
@Service
|
||||
public class DimensionWordBuilder extends BaseWordBuilder {
|
||||
|
||||
@Value("${nlp.dimension.use.suffix:true}")
|
||||
private boolean nlpDimensionUseSuffix = true;
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOnwWordNature(word, schemaElement, false));
|
||||
result.addAll(getOnwWordNatureAlias(schemaElement, false));
|
||||
if (nlpDimensionUseSuffix) {
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (StringUtils.isNotEmpty(word) && !word.equalsIgnoreCase(reverseWord)) {
|
||||
result.add(getOnwWordNature(reverseWord, schemaElement, true));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long domainId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + domainId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.DIMENSION.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + domainId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
for (String alias : schemaElement.getAlias()) {
|
||||
dictWords.add(getOnwWordNature(alias, schemaElement, false));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.builder;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* dimension value wordNature
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EntityWordBuilder extends BaseWordBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
return result;
|
||||
}
|
||||
|
||||
Long domain = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + domain + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.ENTITY.getType();
|
||||
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
schemaElement.getAlias().stream().forEach(alias -> {
|
||||
DictWord dictWordAlias = new DictWord();
|
||||
dictWordAlias.setWord(alias);
|
||||
dictWordAlias.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature));
|
||||
result.add(dictWordAlias);
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Metric DictWord
|
||||
*/
|
||||
@Service
|
||||
public class MetricWordBuilder extends BaseWordBuilder {
|
||||
|
||||
@Value("${nlp.metric.use.suffix:true}")
|
||||
private boolean nlpMetricUseSuffix = true;
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOnwWordNature(word, schemaElement, false));
|
||||
result.addAll(getOnwWordNatureAlias(schemaElement, false));
|
||||
if (nlpMetricUseSuffix) {
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (!word.equalsIgnoreCase(reverseWord)) {
|
||||
result.add(getOnwWordNature(reverseWord, schemaElement, true));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.METRIC.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
for (String alias : schemaElement.getAlias()) {
|
||||
dictWords.add(getOnwWordNature(alias, schemaElement, false));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* model word nature
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ModelWordBuilder extends BaseWordBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
//modelName
|
||||
DictWord dictWord = buildDictWord(word, schemaElement.getModel());
|
||||
result.add(dictWord);
|
||||
//alias
|
||||
List<String> aliasList = schemaElement.getAlias();
|
||||
if (CollectionUtils.isNotEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(buildDictWord(alias, schemaElement.getModel()));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private DictWord buildDictWord(String word, Long modelId) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
String nature = DictWordType.NATURE_SPILT + modelId;
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.builder;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* dimension value wordNature
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ValueWordBuilder extends BaseWordBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
if (Objects.nonNull(schemaElement) && !CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
|
||||
schemaElement.getAlias().stream().forEach(value -> {
|
||||
DictWord dictWord = new DictWord();
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId();
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
dictWord.setWord(value);
|
||||
result.add(dictWord);
|
||||
});
|
||||
}
|
||||
log.debug("ValueWordBuilder, result:{}", result);
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.builder;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* DictWord Strategy Factory
|
||||
*/
|
||||
public class WordBuilderFactory {
|
||||
|
||||
private static Map<DictWordType, BaseWordBuilder> wordNatures = new ConcurrentHashMap<>();
|
||||
|
||||
static {
|
||||
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
|
||||
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
|
||||
wordNatures.put(DictWordType.MODEL, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
||||
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
||||
}
|
||||
|
||||
public static BaseWordBuilder get(DictWordType strategyType) {
|
||||
return wordNatures.get(strategyType);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
|
||||
|
||||
protected final Cache<String, List<ModelSchemaResp>> modelSchemaCache =
|
||||
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
|
||||
|
||||
@SneakyThrows
|
||||
public List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable) {
|
||||
if (cacheEnable) {
|
||||
return modelSchemaCache.get(String.valueOf(ids), () -> {
|
||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
||||
modelSchemaCache.put(String.valueOf(ids), data);
|
||||
return data;
|
||||
});
|
||||
}
|
||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelSchema getModelSchema(Long model, Boolean cacheEnable) {
|
||||
List<Long> ids = new ArrayList<>();
|
||||
ids.add(model);
|
||||
List<ModelSchemaResp> modelSchemaResps = fetchModelSchema(ids, cacheEnable);
|
||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
||||
Optional<ModelSchemaResp> modelSchemaResp = modelSchemaResps.stream()
|
||||
.filter(d -> d.getId().equals(model)).findFirst();
|
||||
if (modelSchemaResp.isPresent()) {
|
||||
ModelSchemaResp modelSchema = modelSchemaResp.get();
|
||||
return ModelSchemaBuilder.build(modelSchema);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchema> getModelSchema() {
|
||||
return getModelSchema(new ArrayList<>());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchema> getModelSchema(List<Long> ids) {
|
||||
List<ModelSchema> domainSchemaList = new ArrayList<>();
|
||||
|
||||
for (ModelSchemaResp resp : fetchModelSchema(ids, true)) {
|
||||
domainSchemaList.add(ModelSchemaBuilder.build(resp));
|
||||
}
|
||||
|
||||
return domainSchemaList;
|
||||
}
|
||||
|
||||
protected abstract List<ModelSchemaResp> doFetchModelSchema(List<Long> ids);
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.ModelSchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
|
||||
private SchemaService schemaService;
|
||||
private DimensionService dimensionService;
|
||||
private MetricService metricService;
|
||||
private QueryService queryService;
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
|
||||
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
|
||||
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
|
||||
queryS2SQLReq.setSql(queryStructReq.getCorrectS2SQL());
|
||||
queryS2SQLReq.setModelIds(queryStructReq.getModelIdSet());
|
||||
queryS2SQLReq.setVariables(new HashMap<>());
|
||||
return queryByS2SQL(queryS2SQLReq, user);
|
||||
}
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.queryByStructWithAuth(queryStructReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) {
|
||||
try {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.queryByMultiStruct(queryMultiStructReq, user);
|
||||
} catch (Exception e) {
|
||||
log.info("queryByMultiStruct has an exception:{}", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user) {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
Object object = queryService.queryBySql(queryS2SQLReq, user);
|
||||
return JsonUtil.toObject(JsonUtil.toString(object), QueryResultWithSchemaResp.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.queryDimValue(queryDimValueReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
|
||||
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
|
||||
filter.setModelIds(ids);
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
User user = User.getFakeUser();
|
||||
return schemaService.fetchModelSchema(filter, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DomainResp> getDomainList(User user) {
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.getDomainList(user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelResp> getModelList(AuthType authType, Long domainId, User user) {
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.getModelList(user, authType, domainId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.explain(explainSqlReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd) {
|
||||
dimensionService = ContextUtils.getBean(DimensionService.class);
|
||||
return dimensionService.queryDimension(pageDimensionCmd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricReq, User user) {
|
||||
metricService = ContextUtils.getBean(MetricService.class);
|
||||
return metricService.queryMetric(pageMetricReq, user);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ModelSchemaBuilder {
|
||||
|
||||
public static ModelSchema build(ModelSchemaResp resp) {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
SchemaElement model = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(resp.getId())
|
||||
.name(resp.getName())
|
||||
.bizName(resp.getBizName())
|
||||
.type(SchemaElementType.MODEL)
|
||||
.alias(SchemaItem.getAliasList(resp.getAlias()))
|
||||
.build();
|
||||
modelSchema.setModel(model);
|
||||
modelSchema.setModelRelas(resp.getModelRelas());
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
for (MetricSchemaResp metric : resp.getMetrics()) {
|
||||
|
||||
List<String> alias = SchemaItem.getAliasList(metric.getAlias());
|
||||
|
||||
SchemaElement metricToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(metric.getId())
|
||||
.name(metric.getName())
|
||||
.bizName(metric.getBizName())
|
||||
.type(SchemaElementType.METRIC)
|
||||
.useCnt(metric.getUseCnt())
|
||||
.alias(alias)
|
||||
.relatedSchemaElements(getRelateSchemaElement(metric))
|
||||
.defaultAgg(metric.getDefaultAgg())
|
||||
.build();
|
||||
metrics.add(metricToAdd);
|
||||
|
||||
}
|
||||
modelSchema.getMetrics().addAll(metrics);
|
||||
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
Set<SchemaElement> tags = new HashSet<>();
|
||||
for (DimSchemaResp dim : resp.getDimensions()) {
|
||||
|
||||
List<String> alias = SchemaItem.getAliasList(dim.getAlias());
|
||||
Set<String> dimValueAlias = new HashSet<>();
|
||||
List<DimValueMap> dimValueMaps = dim.getDimValueMaps();
|
||||
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
|
||||
if (!CollectionUtils.isEmpty(dimValueMaps)) {
|
||||
|
||||
for (DimValueMap dimValueMap : dimValueMaps) {
|
||||
if (Strings.isNotEmpty(dimValueMap.getBizName())) {
|
||||
dimValueAlias.add(dimValueMap.getBizName());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(dimValueMap.getAlias())) {
|
||||
dimValueAlias.addAll(dimValueMap.getAlias());
|
||||
}
|
||||
SchemaValueMap schemaValueMap = new SchemaValueMap();
|
||||
BeanUtils.copyProperties(dimValueMap, schemaValueMap);
|
||||
schemaValueMaps.add(schemaValueMap);
|
||||
}
|
||||
|
||||
}
|
||||
SchemaElement dimToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
.type(SchemaElementType.DIMENSION)
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(alias)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(dimToAdd);
|
||||
|
||||
SchemaElement dimValueToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
.type(SchemaElementType.VALUE)
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(new ArrayList<>(Arrays.asList(dimValueAlias.toArray(new String[0]))))
|
||||
.build();
|
||||
dimensionValues.add(dimValueToAdd);
|
||||
if (dim.getIsTag() == 1) {
|
||||
SchemaElement tagToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
.type(SchemaElementType.TAG)
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(alias)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
tags.add(tagToAdd);
|
||||
}
|
||||
}
|
||||
modelSchema.getDimensions().addAll(dimensions);
|
||||
modelSchema.getDimensionValues().addAll(dimensionValues);
|
||||
modelSchema.getTags().addAll(tags);
|
||||
|
||||
DimSchemaResp dim = resp.getPrimaryKey();
|
||||
if (dim != null) {
|
||||
SchemaElement entity = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
.type(SchemaElementType.ENTITY)
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(dim.getEntityAlias())
|
||||
.build();
|
||||
modelSchema.setEntity(entity);
|
||||
}
|
||||
return modelSchema;
|
||||
}
|
||||
|
||||
private static List<RelatedSchemaElement> getRelateSchemaElement(MetricSchemaResp metricSchemaResp) {
|
||||
RelateDimension relateDimension = metricSchemaResp.getRelateDimension();
|
||||
if (relateDimension == null || CollectionUtils.isEmpty(relateDimension.getDrillDownDimensions())) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return relateDimension.getDrillDownDimensions().stream().map(dimension -> {
|
||||
RelatedSchemaElement relateSchemaElement = new RelatedSchemaElement();
|
||||
BeanUtils.copyProperties(dimension, relateSchemaElement);
|
||||
return relateSchemaElement;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,313 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||
import com.tencent.supersonic.auth.api.authentication.constant.UserConstants;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.config.DefaultSemanticConfig;
|
||||
import com.tencent.supersonic.common.pojo.ResultData;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.pojo.enums.ReturnCode;
|
||||
import com.tencent.supersonic.common.pojo.exception.CommonException;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||
import com.tencent.supersonic.common.util.ThreadContext;
|
||||
import com.tencent.supersonic.headless.api.request.ModelSchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
@Slf4j
|
||||
public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
|
||||
private S2ThreadContext s2ThreadContext;
|
||||
|
||||
private AuthenticationConfig authenticationConfig;
|
||||
|
||||
private ParameterizedTypeReference<ResultData<QueryResultWithSchemaResp>> structTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<QueryResultWithSchemaResp>>() {
|
||||
};
|
||||
|
||||
private ParameterizedTypeReference<ResultData<ExplainResp>> explainTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<ExplainResp>>() {
|
||||
};
|
||||
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
|
||||
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
|
||||
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
|
||||
queryS2SQLReq.setSql(queryStructReq.getCorrectS2SQL());
|
||||
queryS2SQLReq.setModelIds(queryStructReq.getModelIdSet());
|
||||
queryS2SQLReq.setVariables(new HashMap<>());
|
||||
return queryByS2SQL(queryS2SQLReq, user);
|
||||
}
|
||||
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
return searchByRestTemplate(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchByStructPath(),
|
||||
new Gson().toJson(queryStructReq));
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
return searchByRestTemplate(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchByMultiStructPath(),
|
||||
new Gson().toJson(queryMultiStructReq));
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
return searchByRestTemplate(defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchBySqlPath(),
|
||||
new Gson().toJson(queryS2SQLReq));
|
||||
}
|
||||
|
||||
public QueryResultWithSchemaResp searchByRestTemplate(String url, String jsonReq) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonReq, headers);
|
||||
log.info("url:{},searchByRestTemplate:{}", url, entity.getBody());
|
||||
ResultData<QueryResultWithSchemaResp> responseBody;
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
|
||||
ResponseEntity<ResultData<QueryResultWithSchemaResp>> responseEntity = restTemplate.exchange(
|
||||
requestUrl, HttpMethod.POST, entity, structTypeRef);
|
||||
responseBody = responseEntity.getBody();
|
||||
log.info("ApiResponse<QueryResultWithColumns> responseBody:{}", responseBody);
|
||||
QueryResultWithSchemaResp schemaResp = new QueryResultWithSchemaResp();
|
||||
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
|
||||
QueryResultWithSchemaResp data = responseBody.getData();
|
||||
schemaResp.setColumns(data.getColumns());
|
||||
schemaResp.setResultList(data.getResultList());
|
||||
schemaResp.setSql(data.getSql());
|
||||
schemaResp.setQueryAuthorization(data.getQueryAuthorization());
|
||||
return schemaResp;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("search headless interface error,url:" + url, e);
|
||||
}
|
||||
throw new CommonException(responseBody.getCode(), responseBody.getMsg());
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
return searchByRestTemplate(defaultSemanticConfig.getSemanticUrl()
|
||||
+ defaultSemanticConfig.getQueryDimValuePath(),
|
||||
new Gson().toJson(queryDimValueReq));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.set(UserConstants.INTERNAL, TRUE_LOWER);
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
|
||||
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
|
||||
String fetchModelSchemaPath = defaultSemanticConfig.getFetchModelSchemaPath();
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(semanticUrl + fetchModelSchemaPath)
|
||||
.build().encode().toUri();
|
||||
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
|
||||
filter.setModelIds(ids);
|
||||
ParameterizedTypeReference<ResultData<List<ModelSchemaResp>>> responseTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<List<ModelSchemaResp>>>() {
|
||||
};
|
||||
|
||||
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(filter), headers);
|
||||
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
ResponseEntity<ResultData<List<ModelSchemaResp>>> responseEntity = restTemplate.exchange(
|
||||
requestUrl, HttpMethod.POST, entity, responseTypeRef);
|
||||
ResultData<List<ModelSchemaResp>> responseBody = responseEntity.getBody();
|
||||
log.debug("ApiResponse<fetchModelSchema> responseBody:{}", responseBody);
|
||||
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
|
||||
List<ModelSchemaResp> data = responseBody.getData();
|
||||
return data;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("fetchModelSchema interface error", e);
|
||||
}
|
||||
throw new RuntimeException("fetchModelSchema interface error");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DomainResp> getDomainList(User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
Object domainDescListObject = fetchHttpResult(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchDomainListPath(),
|
||||
null, HttpMethod.GET);
|
||||
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), DomainResp.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelResp> getModelList(AuthType authType, Long domainId, User user) {
|
||||
if (domainId == null) {
|
||||
domainId = 0L;
|
||||
}
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
String url = String.format("%s?domainId=%s&authType=%s",
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchModelListPath(),
|
||||
domainId, authType.toString());
|
||||
Object domainDescListObject = fetchHttpResult(url, null, HttpMethod.GET);
|
||||
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), ModelResp.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> ExplainResp explain(ExplainSqlReq<T> explainResp, User user) throws Exception {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
|
||||
String explainPath = defaultSemanticConfig.getExplainPath();
|
||||
URL url = new URL(new URL(semanticUrl), explainPath);
|
||||
return explain(url.toString(), JsonUtil.toString(explainResp));
|
||||
}
|
||||
|
||||
public ExplainResp explain(String url, String jsonReq) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonReq, headers);
|
||||
log.info("url:{},explain:{}", url, entity.getBody());
|
||||
ResultData<ExplainResp> responseBody;
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
|
||||
ResponseEntity<ResultData<ExplainResp>> responseEntity = restTemplate.exchange(
|
||||
requestUrl, HttpMethod.POST, entity, explainTypeRef);
|
||||
log.info("ApiResponse<ExplainResp> responseBody:{}", responseEntity);
|
||||
responseBody = responseEntity.getBody();
|
||||
if (Objects.nonNull(responseBody.getData())) {
|
||||
return responseBody.getData();
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("explain interface error,url:" + url, e);
|
||||
}
|
||||
}
|
||||
|
||||
public Object fetchHttpResult(String url, String bodyJson, HttpMethod httpMethod) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
ParameterizedTypeReference<ResultData<Object>> responseTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<Object>>() {
|
||||
};
|
||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(bodyJson), headers);
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
ResponseEntity<ResultData<Object>> responseEntity = restTemplate.exchange(requestUrl,
|
||||
httpMethod, entity, responseTypeRef);
|
||||
ResultData<Object> responseBody = responseEntity.getBody();
|
||||
log.debug("ApiResponse<fetchModelSchema> responseBody:{}", responseBody);
|
||||
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
|
||||
Object data = responseBody.getData();
|
||||
return data;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("fetchModelSchema interface error", e);
|
||||
}
|
||||
throw new RuntimeException("fetchModelSchema interface error");
|
||||
}
|
||||
|
||||
public void fillToken(HttpHeaders headers) {
|
||||
s2ThreadContext = ContextUtils.getBean(S2ThreadContext.class);
|
||||
authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class);
|
||||
ThreadContext threadContext = s2ThreadContext.get();
|
||||
if (Objects.nonNull(threadContext) && Strings.isNotEmpty(threadContext.getToken())) {
|
||||
if (Objects.nonNull(authenticationConfig) && Strings.isNotEmpty(
|
||||
authenticationConfig.getTokenHttpHeaderKey())) {
|
||||
headers.set(authenticationConfig.getTokenHttpHeaderKey(), threadContext.getToken());
|
||||
}
|
||||
} else {
|
||||
log.debug("threadContext is null:{}", Objects.isNull(threadContext));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user) {
|
||||
String body = JsonUtil.toString(pageMetricCmd);
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
log.info("url:{}", defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath());
|
||||
Object dimensionListObject = fetchHttpResult(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath(),
|
||||
body, HttpMethod.POST);
|
||||
LinkedHashMap map = (LinkedHashMap) dimensionListObject;
|
||||
PageInfo<Object> metricDescObjectPageInfo = generatePageInfo(map);
|
||||
PageInfo<MetricResp> metricDescPageInfo = new PageInfo<>();
|
||||
BeanUtils.copyProperties(metricDescObjectPageInfo, metricDescPageInfo);
|
||||
metricDescPageInfo.setList(metricDescPageInfo.getList());
|
||||
return metricDescPageInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd) {
|
||||
String body = JsonUtil.toString(pageDimensionCmd);
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
Object dimensionListObject = fetchHttpResult(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchDimensionPagePath(),
|
||||
body, HttpMethod.POST);
|
||||
LinkedHashMap map = (LinkedHashMap) dimensionListObject;
|
||||
PageInfo<Object> dimensionDescObjectPageInfo = generatePageInfo(map);
|
||||
PageInfo<DimensionResp> dimensionDescPageInfo = new PageInfo<>();
|
||||
BeanUtils.copyProperties(dimensionDescObjectPageInfo, dimensionDescPageInfo);
|
||||
dimensionDescPageInfo.setList(dimensionDescPageInfo.getList());
|
||||
return dimensionDescPageInfo;
|
||||
}
|
||||
|
||||
private PageInfo<Object> generatePageInfo(LinkedHashMap map) {
|
||||
PageInfo<Object> pageInfo = new PageInfo<>();
|
||||
pageInfo.setList((List<Object>) map.get(LIST_LOWER));
|
||||
Integer total = (Integer) map.get(TOTAL_LOWER);
|
||||
pageInfo.setTotal(total);
|
||||
Integer pageSize = (Integer) map.get(PAGESIZE_LOWER);
|
||||
pageInfo.setPageSize(pageSize);
|
||||
pageInfo.setPages((int) Math.ceil((double) total / pageSize));
|
||||
return pageInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A semantic layer provides a simplified and consistent view of data from multiple sources.
|
||||
* It abstracts away the complexity of the underlying data sources and provides a unified view
|
||||
* of the data that is easier to understand and use.
|
||||
* <p>
|
||||
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
|
||||
* Implementations of this interface should provide concrete implementations that interact with the
|
||||
* underlying data sources and return results in a consistent format. Or it can be implemented
|
||||
* as proxy to a remote semantic service.
|
||||
* </p>
|
||||
*/
|
||||
public interface SemanticInterpreter {
|
||||
|
||||
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
|
||||
|
||||
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
||||
|
||||
QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user);
|
||||
|
||||
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||
|
||||
List<ModelSchema> getModelSchema();
|
||||
|
||||
List<ModelSchema> getModelSchema(List<Long> ids);
|
||||
|
||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
||||
|
||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||
|
||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
||||
|
||||
List<DomainResp> getDomainList(User user);
|
||||
|
||||
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
|
||||
|
||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||
|
||||
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
||||
|
||||
}
|
||||
@@ -1,14 +1,12 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -68,10 +66,10 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
|
||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
||||
if (Objects.isNull(modelSchema)) {
|
||||
return null;
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -102,7 +102,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
@@ -1,13 +1,12 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -33,14 +32,12 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
this.allElements = getSchemaElements();
|
||||
this.allElements = getSchemaElements(queryContext);
|
||||
return super.match(queryContext, terms, detectModelIds);
|
||||
}
|
||||
|
||||
@@ -62,7 +59,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
|
||||
@@ -90,10 +87,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
}
|
||||
|
||||
private List<SchemaElement> getSchemaElements() {
|
||||
private List<SchemaElement> getSchemaElements(QueryContext queryContext) {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
|
||||
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
|
||||
allElements.addAll(queryContext.getSemanticSchema().getDimensions());
|
||||
allElements.addAll(queryContext.getSemanticSchema().getMetrics());
|
||||
return allElements;
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -43,7 +43,8 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
long modelId = Long.parseLong(modelIdStr);
|
||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
|
||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId,
|
||||
queryContext.getSemanticSchema());
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -1,28 +1,28 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.server.listener.MetaEmbeddingListener;
|
||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* EmbeddingMatchStrategy uses vector database to perform
|
||||
@@ -35,6 +35,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
|
||||
@Override
|
||||
@@ -86,7 +89,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
.build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
||||
embeddingConfig.getMetaCollectionName(), retrieveQuery, embeddingNumber);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
@@ -1,21 +1,19 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
||||
*/
|
||||
@@ -30,7 +28,7 @@ public class EntityMapper extends BaseMapper {
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement entity = getEntity(modelId);
|
||||
SchemaElement entity = getEntity(modelId, queryContext);
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -66,9 +64,9 @@ public class EntityMapper extends BaseMapper {
|
||||
return false;
|
||||
}
|
||||
|
||||
private SchemaElement getEntity(Long modelId) {
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
|
||||
private SchemaElement getEntity(Long modelId, QueryContext queryContext) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
return modelSchema.getEntity();
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
@@ -69,8 +69,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||
agentId,
|
||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, agentId, detectModelIds).stream()
|
||||
@@ -1,16 +1,16 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -35,7 +35,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
||||
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext.getMapInfo(), terms);
|
||||
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext, terms);
|
||||
|
||||
//2.database Match
|
||||
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||
@@ -44,7 +44,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
|
||||
}
|
||||
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, SchemaMapInfo schemaMap,
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
||||
List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
@@ -65,7 +65,8 @@ public class KeywordMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
|
||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID,
|
||||
queryContext.getSemanticSchema());
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -81,7 +82,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,11 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -83,15 +82,12 @@ public class MapperHelper {
|
||||
detectSegment.length());
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(QueryReq request) {
|
||||
public Set<Long> getModelIds(QueryReq request, Agent agent) {
|
||||
|
||||
Long modelId = request.getModelId();
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
|
||||
Set<Long> detectModelIds = agentService.getModelIds(request.getAgentId(), null);
|
||||
Set<Long> detectModelIds = agent.getModelIds(null);
|
||||
//contains all
|
||||
if (agentService.containsAllModel(detectModelIds)) {
|
||||
if (agent.containsAllModel(detectModelIds)) {
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
result.add(modelId);
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import java.util.Objects;
|
||||
import lombok.Builder;
|
||||
@@ -1,16 +1,12 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
@@ -27,8 +23,7 @@ public class ModelClusterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
|
||||
Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
|
||||
@@ -46,7 +41,7 @@ public class ModelClusterMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
|
||||
SemanticSchema semanticSchema) {
|
||||
SemanticSchema semanticSchema) {
|
||||
Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
|
||||
List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||
return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import java.io.Serializable;
|
||||
@@ -1,22 +1,21 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterMapper implements SchemaMapper {
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
|
||||
/**
|
||||
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
||||
* in user queries. It matches the query text against the knowledge base.
|
||||
*/
|
||||
public interface SchemaMapper {
|
||||
|
||||
void map(QueryContext queryContext);
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -1,15 +1,15 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionPromptGenerator;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import java.util.Objects;
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* LLMProxy encapsulates functions performed by LLMs so that multiple
|
||||
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
|
||||
* could be used.
|
||||
*/
|
||||
public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
FunctionResp requestFunction(FunctionReq functionReq);
|
||||
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
@@ -1,30 +1,25 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
||||
@@ -40,14 +35,14 @@ public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
semanticQuery.initS2Sql(user);
|
||||
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(semanticQuery);
|
||||
QueryType queryType = getQueryType(queryContext, semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(SemanticQuery semanticQuery) {
|
||||
private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
@@ -55,13 +50,13 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
//1. entity queryType
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
|
||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
@@ -77,7 +72,6 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
@@ -1,11 +1,11 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
|
||||
/**
|
||||
* A semantic parser understands user queries and extracts semantic information.
|
||||
* It could leverage either rule-based or LLM-based approach to identify query intent
|
||||
* and extract related semantic items from the query.
|
||||
*/
|
||||
public interface SemanticParser {
|
||||
|
||||
void parse(QueryContext queryContext, ChatContext chatContext);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||
|
||||
public enum ParseMode {
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
@@ -75,7 +75,7 @@ public abstract class PluginParser implements SemanticParser {
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -1,29 +1,26 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* FunctionCallParser is an implementation of a recall plugin based on FunctionCall
|
||||
*/
|
||||
@@ -45,19 +42,17 @@ public class FunctionCallParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
FunctionResp functionResp = functionCall(queryContext);
|
||||
if (skipFunction(functionResp)) {
|
||||
return null;
|
||||
}
|
||||
log.info("requestFunction result:{}", functionResp.getToolSelection());
|
||||
String toolSelection = functionResp.getToolSelection();
|
||||
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
|
||||
if (!pluginOptional.isPresent()) {
|
||||
Plugin plugin = queryContext.getNameToPlugin().get(toolSelection);
|
||||
if (Objects.isNull(plugin)) {
|
||||
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
|
||||
return null;
|
||||
}
|
||||
Plugin plugin = pluginOptional.get();
|
||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||
if (pluginResolveResult.getLeft()) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.InputFormat;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.InputFormat;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
@@ -1,28 +1,26 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import java.util.ArrayList;
|
||||
@@ -49,10 +47,6 @@ public class LLMRequestService {
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
public boolean isSkip(QueryContext queryCtx) {
|
||||
@@ -66,9 +60,9 @@ public class LLMRequestService {
|
||||
return false;
|
||||
}
|
||||
|
||||
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
|
||||
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.NL2SQL_LLM);
|
||||
if (agentService.containsAllModel(distinctModelIds)) {
|
||||
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
Set<Long> distinctModelIds = queryCtx.getAgent().getModelIds(AgentToolType.NL2SQL_LLM);
|
||||
if (queryCtx.getAgent().containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
@@ -77,13 +71,13 @@ public class LLMRequestService {
|
||||
return ModelCluster.build(modelCluster);
|
||||
}
|
||||
|
||||
public NL2SQLTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
|
||||
List<NL2SQLTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
|
||||
AgentToolType.NL2SQL_LLM);
|
||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Set<Long> modelIdSet) {
|
||||
Agent agent = queryCtx.getAgent();
|
||||
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
|
||||
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> modelIds = tool.getModelIds();
|
||||
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
|
||||
if (agent.containsAllModel(new HashSet<>(modelIds))) {
|
||||
return true;
|
||||
}
|
||||
for (Long modelId : modelIdSet) {
|
||||
@@ -127,7 +121,7 @@ public class LLMRequestService {
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(firstModelId);
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, firstModelId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
@@ -143,7 +137,7 @@ public class LLMRequestService {
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
|
||||
Set<String> results = getTopNFieldNames(queryCtx, modelCluster, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
|
||||
|
||||
@@ -187,7 +181,7 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
||||
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
@@ -210,14 +204,15 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(ModelCluster modelCluster) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(ModelCluster modelCluster, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
@@ -235,7 +230,7 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user