From 3b9f80b46c5c6e2503fc1a39613c8bc6c1fecb3a Mon Sep 17 00:00:00 2001 From: balex Date: Sun, 22 Feb 2026 02:52:44 +0100 Subject: [PATCH] vector store by user --- .../java/com/balex/rag/advisors/rag/RagAdvisor.java | 13 +++++++++---- .../balex/rag/controller/ChatEntryController.java | 4 ++-- .../com/balex/rag/service/ChatEntryService.java | 2 +- .../rag/service/impl/ChatEntryServiceImpl.java | 4 +++- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/rag-service/src/main/java/com/balex/rag/advisors/rag/RagAdvisor.java b/rag-service/src/main/java/com/balex/rag/advisors/rag/RagAdvisor.java index 721a3d4..2f9d442 100644 --- a/rag-service/src/main/java/com/balex/rag/advisors/rag/RagAdvisor.java +++ b/rag-service/src/main/java/com/balex/rag/advisors/rag/RagAdvisor.java @@ -43,13 +43,18 @@ public class RagAdvisor implements BaseAdvisor { String originalUserQuestion = chatClientRequest.prompt().getUserMessage().getText(); String queryToRag = chatClientRequest.context().getOrDefault(ENRICHED_QUESTION, originalUserQuestion).toString(); - SearchRequest searchRequest = SearchRequest.builder() + Object userIdObj = chatClientRequest.context().get("USER_ID"); + + SearchRequest.Builder searchBuilder = SearchRequest.builder() .query(queryToRag) .topK(searchTopK * rerankFetchMultiplier) - .similarityThreshold(similarityThreshold) - .build(); + .similarityThreshold(similarityThreshold); - List documents = vectorStore.similaritySearch(searchRequest); + if (userIdObj != null) { + searchBuilder.filterExpression("user_id == " + userIdObj); + } + + List documents = vectorStore.similaritySearch(searchBuilder.build()); if (documents == null || documents.isEmpty()) { return chatClientRequest.mutate().context("CONTEXT", "EMPTY").build(); diff --git a/rag-service/src/main/java/com/balex/rag/controller/ChatEntryController.java b/rag-service/src/main/java/com/balex/rag/controller/ChatEntryController.java index 3e8a003..257eb24 100644 --- a/rag-service/src/main/java/com/balex/rag/controller/ChatEntryController.java +++ b/rag-service/src/main/java/com/balex/rag/controller/ChatEntryController.java @@ -38,9 +38,9 @@ public class ChatEntryController { boolean onlyContext = request.onlyContext() != null ? request.onlyContext() : ragDefaults.onlyContext(); double topP = request.topP() != null ? request.topP() : ragDefaults.topP(); - ChatEntry entry = chatEntryService.addUserEntry(chatId, request.content(), onlyContext, topP); - Chat chat = chatService.getChat(chatId); + ChatEntry entry = chatEntryService.addUserEntry(chatId, request.content(), onlyContext, topP, chat.getIdOwner()); + eventPublisher.publishQuerySent( chat.getIdOwner().toString(), chatId.toString(), diff --git a/rag-service/src/main/java/com/balex/rag/service/ChatEntryService.java b/rag-service/src/main/java/com/balex/rag/service/ChatEntryService.java index 7c58f49..be9723e 100644 --- a/rag-service/src/main/java/com/balex/rag/service/ChatEntryService.java +++ b/rag-service/src/main/java/com/balex/rag/service/ChatEntryService.java @@ -8,5 +8,5 @@ public interface ChatEntryService { List getEntriesByChatId(Long chatId); - ChatEntry addUserEntry(Long chatId, String content, boolean onlyContext, double topP); + ChatEntry addUserEntry(Long chatId, String content, boolean onlyContext, double topP, Long userId); } \ No newline at end of file diff --git a/rag-service/src/main/java/com/balex/rag/service/impl/ChatEntryServiceImpl.java b/rag-service/src/main/java/com/balex/rag/service/impl/ChatEntryServiceImpl.java index 44f95d6..55e875c 100644 --- a/rag-service/src/main/java/com/balex/rag/service/impl/ChatEntryServiceImpl.java +++ b/rag-service/src/main/java/com/balex/rag/service/impl/ChatEntryServiceImpl.java @@ -64,7 +64,9 @@ public class ChatEntryServiceImpl implements ChatEntryService { String response = chatClient.prompt() .system(systemPrompt) .user(content) - .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, String.valueOf(chatId))) + .advisors(a -> a + .param(ChatMemory.CONVERSATION_ID, String.valueOf(chatId)) + .param("USER_ID", userId)) .options(OpenAiChatOptions.builder() .model(ragDefaults.model()) .topP(topP)