GPU
This commit is contained in:
@@ -14,7 +14,7 @@ import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.SpringApplication;
|
||||
@@ -49,9 +49,10 @@ public class RagApplication {
|
||||
.order(3).build(),
|
||||
SimpleLoggerAdvisor.builder().order(4).build()
|
||||
)
|
||||
.defaultOptions(OllamaOptions.builder()
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.temperature(ragDefaults.temperature())
|
||||
.repeatPenalty(ragDefaults.repeatPenalty())
|
||||
.topP(ragDefaults.topP())
|
||||
.frequencyPenalty(ragDefaults.repeatPenalty() - 1.0) // Ollama repeatPenalty 1.1 -> frequencyPenalty 0.1
|
||||
.build())
|
||||
.build();
|
||||
}
|
||||
@@ -71,4 +72,4 @@ public class RagApplication {
|
||||
SpringApplication.run(RagApplication.class, args);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@@ -48,11 +48,10 @@ public class ExpansionQueryAdvisor implements BaseAdvisor {
|
||||
|
||||
public static ExpansionQueryAdvisorBuilder builder(ChatModel chatModel, RagExpansionProperties props) {
|
||||
return new ExpansionQueryAdvisorBuilder().chatClient(ChatClient.builder(chatModel)
|
||||
.defaultOptions(OllamaOptions.builder()
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.temperature(props.temperature())
|
||||
.topK(props.topK())
|
||||
.topP(props.topP())
|
||||
.repeatPenalty(props.repeatPenalty())
|
||||
.frequencyPenalty(props.repeatPenalty() - 1.0) // Ollama repeatPenalty 1.0 -> frequencyPenalty 0.0
|
||||
.build())
|
||||
.build());
|
||||
}
|
||||
@@ -86,4 +85,4 @@ public class ExpansionQueryAdvisor implements BaseAdvisor {
|
||||
return chatClientResponse;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -36,16 +36,15 @@ public class ChatEntryController {
|
||||
log.trace(ApiLogMessage.NAME_OF_CURRENT_METHOD.getValue(), ApiUtils.getMethodName());
|
||||
|
||||
boolean onlyContext = request.onlyContext() != null ? request.onlyContext() : ragDefaults.onlyContext();
|
||||
int topK = request.topK() != null ? request.topK() : ragDefaults.topK();
|
||||
double topP = request.topP() != null ? request.topP() : ragDefaults.topP();
|
||||
|
||||
ChatEntry entry = chatEntryService.addUserEntry(chatId, request.content(), onlyContext, topK, topP);
|
||||
ChatEntry entry = chatEntryService.addUserEntry(chatId, request.content(), onlyContext, topP);
|
||||
|
||||
Chat chat = chatService.getChat(chatId);
|
||||
eventPublisher.publishQuerySent(
|
||||
chat.getIdOwner().toString(),
|
||||
chatId.toString(),
|
||||
0); // TODO: add tokensUsed when Ollama response provides it
|
||||
0); // TODO: add tokensUsed when usage info is available from Groq response
|
||||
|
||||
return ResponseEntity.ok(entry);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,5 @@ package com.balex.rag.model.dto;
|
||||
public record UserEntryRequest(
|
||||
String content,
|
||||
Boolean onlyContext,
|
||||
Integer topK,
|
||||
Double topP
|
||||
) {}
|
||||
) {}
|
||||
@@ -8,5 +8,5 @@ public interface ChatEntryService {
|
||||
|
||||
List<ChatEntry> getEntriesByChatId(Long chatId);
|
||||
|
||||
ChatEntry addUserEntry(Long chatId, String content, boolean onlyContext, int topK, double topP);
|
||||
ChatEntry addUserEntry(Long chatId, String content, boolean onlyContext, double topP);
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@@ -33,7 +33,7 @@ public class ChatEntryServiceImpl implements ChatEntryService {
|
||||
|
||||
@Override
|
||||
@Transactional
|
||||
public ChatEntry addUserEntry(Long chatId, String content, boolean onlyContext, int topK, double topP) {
|
||||
public ChatEntry addUserEntry(Long chatId, String content, boolean onlyContext, double topP) {
|
||||
Chat chat = chatRepository.findById(chatId)
|
||||
.orElseThrow(() -> new EntityNotFoundException("Chat not found with id: " + chatId));
|
||||
|
||||
@@ -63,8 +63,7 @@ public class ChatEntryServiceImpl implements ChatEntryService {
|
||||
.system(systemPrompt)
|
||||
.user(content)
|
||||
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, String.valueOf(chatId)))
|
||||
.options(OllamaOptions.builder()
|
||||
.topK(topK)
|
||||
.options(OpenAiChatOptions.builder()
|
||||
.topP(topP)
|
||||
.build())
|
||||
.call()
|
||||
|
||||
Reference in New Issue
Block a user