[fix]Fix unit test cases.

This commit is contained in:
jerryjzhang
2025-08-05 17:22:10 +08:00
parent bf3213e8fb
commit 91e4b51ef8
22 changed files with 197 additions and 115 deletions

View File

@@ -86,7 +86,7 @@ public class FileHandlerImpl implements FileHandler {
}
private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
DictValueReq dictValueReq) {
DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
@@ -95,7 +95,7 @@ public class FileHandlerImpl implements FileHandler {
Integer startLine = 1;
List<DictValueResp> dictValueRespList =
getFileData(filePath, startLine, fileLineNum.intValue()).stream().filter(
dictValue -> dictValue.getValue().contains(dictValueReq.getKeyValue()))
dictValue -> dictValue.getValue().contains(dictValueReq.getKeyValue()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dictValueRespList)) {
dictValueRespPageInfo.setList(new ArrayList<>());
@@ -118,7 +118,7 @@ public class FileHandlerImpl implements FileHandler {
}
private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
DictValueReq dictValueReq) {
DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
Long fileLineNum = getFileLineNum(filePath);
@@ -175,7 +175,7 @@ public class FileHandlerImpl implements FileHandler {
private DictValueResp convert2Resp(String lineStr) {
DictValueResp dictValueResp = new DictValueResp();
if (StringUtils.isNotEmpty(lineStr)) {
lineStr=StringUtils.stripStart(lineStr,null);
lineStr = StringUtils.stripStart(lineStr, null);
String[] itemArray = lineStr.split("\\s+");
if (Objects.nonNull(itemArray) && itemArray.length >= 3) {
dictValueResp.setValue(itemArray[0].replace("#", " "));

View File

@@ -63,7 +63,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
@Override
public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
if (chatQueryContext == null || CollectionUtils.isEmpty(detectDataSetIds)) {
log.warn("Invalid input parameters: context={}, dataSetIds={}", chatQueryContext,
detectDataSetIds);
@@ -92,7 +92,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* Perform enhanced detection using LLM
*/
private List<EmbeddingResult> detectWithLLM(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
try {
String queryText = chatQueryContext.getRequest().getQueryText();
if (StringUtils.isBlank(queryText)) {
@@ -126,7 +126,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
@Override
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments) {
Set<Long> detectDataSetIds, Set<String> detectSegments) {
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, false);
}
@@ -140,7 +140,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* @return List of embedding results
*/
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
int embeddingMapperBatch = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
@@ -168,10 +168,11 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
variable.put("retrievedInfo", JSONObject.toJSONString(results));
Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable);
ChatModelConfig chatModelConfig=null;
if(chatQueryContext.getRequest().getChatAppConfig()!=null
&& chatQueryContext.getRequest().getChatAppConfig().containsKey("REWRITE_MULTI_TURN")){
chatModelConfig=chatQueryContext.getRequest().getChatAppConfig().get("REWRITE_MULTI_TURN").getChatModelConfig();
ChatModelConfig chatModelConfig = null;
if (chatQueryContext.getRequest().getChatAppConfig() != null && chatQueryContext
.getRequest().getChatAppConfig().containsKey("REWRITE_MULTI_TURN")) {
chatModelConfig = chatQueryContext.getRequest().getChatAppConfig()
.get("REWRITE_MULTI_TURN").getChatModelConfig();
}
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatModelConfig);
String response = chatLanguageModel.generate(prompt.toUserMessage().singleText());
@@ -200,7 +201,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* @return Callable task
*/
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
return () -> {
List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
queryTextsSub, chatQueryContext, useLlm);
@@ -221,7 +222,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* @return List of embedding results for this batch
*/
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
// Get configuration parameters
@@ -243,12 +244,12 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
// Process results
List<EmbeddingResult> collect = retrieveQueryResults.stream().peek(result -> {
if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
result.getRetrieval()
.removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
&& retrieval.getSimilarity() < threshold);
}
}).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
result.getRetrieval()
.removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
&& retrieval.getSimilarity() < threshold);
}
}).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
.flatMap(result -> result.getRetrieval().stream()
.map(retrieval -> convertToEmbeddingResult(result, retrieval)))
.collect(Collectors.toList());
@@ -267,7 +268,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* @return Converted EmbeddingResult
*/
private EmbeddingResult convertToEmbeddingResult(RetrieveQueryResult queryResult,
Retrieval retrieval) {
Retrieval retrieval) {
EmbeddingResult embeddingResult = new EmbeddingResult();
BeanUtils.copyProperties(retrieval, embeddingResult);
embeddingResult.setDetectWord(queryResult.getQuery());

View File

@@ -51,7 +51,7 @@ public class KeywordMapper extends BaseMapper {
}
private void convertMapResultToMapInfo(List<HanlpMapResult> mapResults,
ChatQueryContext chatQueryContext, List<S2Term> terms) {
ChatQueryContext chatQueryContext, List<S2Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
return;
}
@@ -87,14 +87,15 @@ public class KeywordMapper extends BaseMapper {
.similarity(hanlpMapResult.getSimilarity())
.detectWord(hanlpMapResult.getDetectWord()).build();
// doDimValueAliasLogic 将维度值别名进行替换成真实维度值
doDimValueAliasLogic(schemaElementMatch,chatQueryContext.getSemanticSchema().getDimensionValues());
doDimValueAliasLogic(schemaElementMatch,
chatQueryContext.getSemanticSchema().getDimensionValues());
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}
private void doDimValueAliasLogic(SchemaElementMatch schemaElementMatch,
List<SchemaElement> dimensionValues) {
List<SchemaElement> dimensionValues) {
SchemaElement element = schemaElementMatch.getElement();
if (SchemaElementType.VALUE.equals(element.getType())) {
Long dimId = element.getId();
@@ -126,7 +127,7 @@ public class KeywordMapper extends BaseMapper {
}
private void convertMapResultToMapInfo(ChatQueryContext chatQueryContext,
List<DatabaseMapResult> mapResults) {
List<DatabaseMapResult> mapResults) {
for (DatabaseMapResult match : mapResults) {
SchemaElement schemaElement = match.getSchemaElement();
Set<Long> regElementSet =
@@ -153,8 +154,8 @@ public class KeywordMapper extends BaseMapper {
return new HashSet<>();
}
return elements.stream().filter(
elementMatch -> SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
elementMatch -> SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.getElement().getId()).collect(Collectors.toSet());
}
}