Files
play-life/play-life-llm/internal/handler/ask.go
2026-02-08 17:01:36 +03:00

178 lines
6.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
	"encoding/json"
	"log"
	"net/http"
	"unicode/utf8"

	"play-life-llm/internal/ollama"
	"play-life-llm/internal/tavily"
)
// AskRequest is the POST /ask body.
type AskRequest struct {
	// Prompt is the user's question, forwarded verbatim to the model.
	Prompt string `json:"prompt"`
	// ResponseSchema is the JSON schema the model's answer must conform to;
	// it is passed to Ollama as the request's "format" field.
	ResponseSchema interface{} `json:"response_schema"`
	// Model overrides the handler's default model when non-empty.
	Model string `json:"model,omitempty"`
	// AllowWebSearch: if true, tools (web_search) are added to the Ollama
	// request, and model tool calls are executed via Tavily. If false (the
	// default), no tools are passed — the model simply returns JSON matching
	// the schema (suitable for simple requests that need no internet access).
	AllowWebSearch bool `json:"allow_web_search,omitempty"`
}
// AskResponse is the successful response (result is JSON by schema).
type AskResponse struct {
	// Result holds the model's answer as raw JSON so it reaches the client
	// unmodified (already validated/parsed by the handler before sending).
	Result json.RawMessage `json:"result"`
}
// AskHandler handles POST /ask: prompt + response_schema -> LLM with optional web search, returns JSON.
type AskHandler struct {
	// Ollama is the chat backend; it is dereferenced on every request and
	// must be non-nil.
	Ollama *ollama.Client
	// Tavily is the web-search backend. NOTE(review): it is dereferenced only
	// when allow_web_search is set — confirm it is always wired whenever that
	// flag can be enabled, otherwise the handler panics.
	Tavily *tavily.Client
	// DefaultModel is used when the request does not specify a model.
	DefaultModel string
}
// maxSearchQueryBytes caps the length (in bytes) of search queries sent to Tavily.
const maxSearchQueryBytes = 200

// truncateUTF8 shortens s to at most max bytes, backing up to a rune boundary
// so that a multi-byte UTF-8 character (common in non-ASCII prompts, e.g.
// Russian) is never split. A raw s[:max] cut could produce invalid UTF-8.
func truncateUTF8(s string, max int) string {
	if len(s) <= max {
		return s
	}
	cut := max
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut]
}

// ServeHTTP implements http.Handler for POST /ask.
//
// Flow:
//  1. Validate the method and body (prompt and response_schema are required).
//  2. If allow_web_search is set, seed the conversation with a tool-usage
//     instruction plus a guaranteed Tavily pre-search over the user prompt.
//  3. Run up to maxToolRounds chat rounds against Ollama, executing any
//     web_search tool calls via Tavily between rounds.
//  4. On the final (tool-call-free) answer, reply with {"result": <JSON>}.
func (h *AskHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req AskRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		sendError(w, "invalid JSON body", http.StatusBadRequest)
		return
	}
	if req.Prompt == "" {
		sendError(w, "prompt is required", http.StatusBadRequest)
		return
	}
	if req.ResponseSchema == nil {
		sendError(w, "response_schema is required", http.StatusBadRequest)
		return
	}

	// Model resolution: request override -> handler default -> hard-coded fallback.
	model := req.Model
	if model == "" {
		model = h.DefaultModel
	}
	if model == "" {
		model = "llama3.1:70b"
	}

	var tools []ollama.Tool
	if req.AllowWebSearch {
		tools = []ollama.Tool{ollama.WebSearchTool()}
	}

	messages := h.seedMessages(&req)
	messages = append(messages, ollama.ChatMessage{
		Role: "user", Content: req.Prompt,
	})

	// Bound the tool-call loop so a model that keeps requesting searches
	// cannot spin forever.
	const maxToolRounds = 20
	for round := 0; round < maxToolRounds; round++ {
		chatReq := &ollama.ChatRequest{
			Model:    model,
			Messages: messages,
			Stream:   false,
			Format:   req.ResponseSchema,
			Tools:    tools,
		}
		resp, err := h.Ollama.Chat(chatReq)
		if err != nil {
			log.Printf("ollama chat error: %v", err)
			sendError(w, "ollama request failed: "+err.Error(), http.StatusBadGateway)
			return
		}
		messages = append(messages, resp.Message)
		if n := len(resp.Message.ToolCalls); n > 0 {
			log.Printf("ollama returned %d tool_calls", n)
		}
		if len(resp.Message.ToolCalls) == 0 {
			// Final answer: message.content is JSON by schema.
			writeResult(w, resp.Message.Content)
			return
		}
		// Execute tool calls (web_search via Tavily) and feed results back.
		messages = h.appendToolResults(messages, resp.Message, req.Prompt)
	}
	// Too many tool rounds.
	sendError(w, "too many tool-call rounds", http.StatusBadGateway)
}

// seedMessages builds the initial system messages. With web search disabled it
// returns nil. Otherwise it adds the tool-usage instruction and a guaranteed
// Tavily pre-search over the user prompt, so the model has fresh data in its
// context even if it never calls the tool itself.
func (h *AskHandler) seedMessages(req *AskRequest) []ollama.ChatMessage {
	if !req.AllowWebSearch {
		return nil
	}
	messages := []ollama.ChatMessage{{
		Role:    "system",
		Content: "When the user asks for current, recent, or real-time information (weather, prices, news, etc.), you MUST call the web_search tool with a suitable query. Do not answer from memory — use the tool and then summarize the results in your response.",
	}}
	searchQuery := truncateUTF8(req.Prompt, maxSearchQueryBytes)
	log.Printf("tavily pre-search: query=%q", searchQuery)
	preSearchResult, err := h.Tavily.Search(searchQuery)
	if err != nil {
		log.Printf("tavily pre-search error: %v", err)
		preSearchResult = "search failed: " + err.Error()
	} else {
		log.Printf("tavily pre-search ok: %d bytes", len(preSearchResult))
	}
	return append(messages, ollama.ChatMessage{
		Role:    "system",
		Content: "Relevant web search result for the user's question (use this to answer; if not enough, you may call web_search again):\n\n" + preSearchResult,
	})
}

// appendToolResults executes every tool call in msg (only web_search is
// supported; anything else gets an "unknown tool" reply) and appends one
// "tool"-role message per call. Returns the extended message slice.
func (h *AskHandler) appendToolResults(messages []ollama.ChatMessage, msg ollama.ChatMessage, userPrompt string) []ollama.ChatMessage {
	for _, tc := range msg.ToolCalls {
		if tc.Function.Name != "web_search" {
			messages = append(messages, ollama.ChatMessage{
				Role: "tool", ToolName: tc.Function.Name, Content: "unknown tool",
			})
			continue
		}
		query := ollama.QueryFromToolCall(tc)
		if query == "" {
			// Some models put fields other than "query" into arguments —
			// fall back to the user prompt as the search query.
			query = truncateUTF8(userPrompt, maxSearchQueryBytes)
			log.Printf("web_search: query empty in tool_call, using user prompt (first 200 chars)")
		}
		log.Printf("tavily search: query=%q", query)
		searchResult, err := h.Tavily.Search(query)
		if err != nil {
			log.Printf("tavily search error: %v", err)
			searchResult = "search failed: " + err.Error()
		} else {
			log.Printf("tavily search ok: %d bytes", len(searchResult))
		}
		messages = append(messages, ollama.ChatMessage{
			Role: "tool", ToolName: "web_search", Content: searchResult,
		})
	}
	return messages
}

// writeResult sends the model's final content as {"result": <JSON>}. If the
// content is not itself valid JSON, it is wrapped as a JSON string so the
// client always receives valid JSON. Empty content is a gateway error.
func writeResult(w http.ResponseWriter, content string) {
	if content == "" {
		sendError(w, "empty response from model", http.StatusBadGateway)
		return
	}
	var raw json.RawMessage
	if err := json.Unmarshal([]byte(content), &raw); err != nil {
		// Not valid JSON: return it as a string inside result.
		raw = json.RawMessage(`"` + escapeJSONString(content) + `"`)
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(w).Encode(AskResponse{Result: raw})
}
func sendError(w http.ResponseWriter, msg string, code int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(map[string]string{"error": msg})
}
// escapeJSONString returns s encoded per JSON string rules with the
// surrounding double quotes stripped, suitable for splicing between quotes
// inside a hand-built JSON document.
func escapeJSONString(s string) string {
	// json.Marshal of a string never fails; it yields `"..."` with all
	// necessary escapes applied.
	quoted, _ := json.Marshal(s)
	return string(quoted[1 : len(quoted)-1])
}