177
play-life-llm/internal/handler/ask.go
Normal file
177
play-life-llm/internal/handler/ask.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"play-life-llm/internal/ollama"
|
||||
"play-life-llm/internal/tavily"
|
||||
)
|
||||
|
||||
// AskRequest is the JSON body accepted by POST /ask.
type AskRequest struct {
	// Prompt is the user's question; it is sent to the model as the user message.
	Prompt string `json:"prompt"`

	// ResponseSchema is forwarded to Ollama as the "format" of the chat request.
	ResponseSchema interface{} `json:"response_schema"`

	// Model optionally overrides the model name used for this request.
	Model string `json:"model,omitempty"`

	// AllowWebSearch: when true, tools (web_search) are added to the Ollama
	// request and a Tavily search is executed whenever the model calls the
	// tool. When false (the default), no tools are passed and the model simply
	// returns JSON matching the schema (suitable for simple requests that do
	// not need the internet).
	AllowWebSearch bool `json:"allow_web_search,omitempty"`
}
|
||||
|
||||
// AskResponse is the body of a successful POST /ask reply.
type AskResponse struct {
	// Result carries the model's answer verbatim as raw JSON.
	Result json.RawMessage `json:"result"`
}
|
||||
|
||||
// AskHandler handles POST /ask: prompt + response_schema -> LLM with optional web search, returns JSON.
|
||||
type AskHandler struct {
|
||||
Ollama *ollama.Client
|
||||
Tavily *tavily.Client
|
||||
DefaultModel string
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (h *AskHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req AskRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
sendError(w, "invalid JSON body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
sendError(w, "prompt is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.ResponseSchema == nil {
|
||||
sendError(w, "response_schema is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = h.DefaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "llama3.1:70b"
|
||||
}
|
||||
|
||||
var tools []ollama.Tool
|
||||
if req.AllowWebSearch {
|
||||
tools = []ollama.Tool{ollama.WebSearchTool()}
|
||||
}
|
||||
messages := []ollama.ChatMessage{}
|
||||
if req.AllowWebSearch {
|
||||
messages = append(messages, ollama.ChatMessage{
|
||||
Role: "system",
|
||||
Content: "When the user asks for current, recent, or real-time information (weather, prices, news, etc.), you MUST call the web_search tool with a suitable query. Do not answer from memory — use the tool and then summarize the results in your response.",
|
||||
})
|
||||
// Гарантированный запрос в Tavily: предпоиск по промпту пользователя, результат подмешивается в контекст.
|
||||
searchQuery := req.Prompt
|
||||
if len(searchQuery) > 200 {
|
||||
searchQuery = searchQuery[:200]
|
||||
}
|
||||
log.Printf("tavily pre-search: query=%q", searchQuery)
|
||||
preSearchResult, err := h.Tavily.Search(searchQuery)
|
||||
if err != nil {
|
||||
log.Printf("tavily pre-search error: %v", err)
|
||||
preSearchResult = "search failed: " + err.Error()
|
||||
} else {
|
||||
log.Printf("tavily pre-search ok: %d bytes", len(preSearchResult))
|
||||
}
|
||||
messages = append(messages, ollama.ChatMessage{
|
||||
Role: "system",
|
||||
Content: "Relevant web search result for the user's question (use this to answer; if not enough, you may call web_search again):\n\n" + preSearchResult,
|
||||
})
|
||||
}
|
||||
messages = append(messages, ollama.ChatMessage{
|
||||
Role: "user", Content: req.Prompt,
|
||||
})
|
||||
|
||||
const maxToolRounds = 20
|
||||
for round := 0; round < maxToolRounds; round++ {
|
||||
chatReq := &ollama.ChatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
Format: req.ResponseSchema,
|
||||
Tools: tools,
|
||||
}
|
||||
resp, err := h.Ollama.Chat(chatReq)
|
||||
if err != nil {
|
||||
log.Printf("ollama chat error: %v", err)
|
||||
sendError(w, "ollama request failed: "+err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
messages = append(messages, resp.Message)
|
||||
|
||||
if n := len(resp.Message.ToolCalls); n > 0 {
|
||||
log.Printf("ollama returned %d tool_calls", n)
|
||||
}
|
||||
if len(resp.Message.ToolCalls) == 0 {
|
||||
// Final answer: message.content is JSON by schema
|
||||
content := resp.Message.Content
|
||||
if content == "" {
|
||||
sendError(w, "empty response from model", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
// Return as { "result": <parsed JSON> } so client gets valid JSON
|
||||
var raw json.RawMessage
|
||||
if err := json.Unmarshal([]byte(content), &raw); err != nil {
|
||||
// If not valid JSON, return as string inside result
|
||||
raw = json.RawMessage(`"` + escapeJSONString(content) + `"`)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(AskResponse{Result: raw})
|
||||
return
|
||||
}
|
||||
|
||||
// Execute tool calls (web_search via Tavily)
|
||||
for _, tc := range resp.Message.ToolCalls {
|
||||
if tc.Function.Name != "web_search" {
|
||||
messages = append(messages, ollama.ChatMessage{
|
||||
Role: "tool", ToolName: tc.Function.Name, Content: "unknown tool",
|
||||
})
|
||||
continue
|
||||
}
|
||||
query := ollama.QueryFromToolCall(tc)
|
||||
if query == "" {
|
||||
// Некоторые модели подставляют в arguments не "query", а другие поля — используем промпт пользователя как поисковый запрос
|
||||
query = req.Prompt
|
||||
if len(query) > 200 {
|
||||
query = query[:200]
|
||||
}
|
||||
log.Printf("web_search: query empty in tool_call, using user prompt (first 200 chars)")
|
||||
}
|
||||
log.Printf("tavily search: query=%q", query)
|
||||
searchResult, err := h.Tavily.Search(query)
|
||||
if err != nil {
|
||||
log.Printf("tavily search error: %v", err)
|
||||
searchResult = "search failed: " + err.Error()
|
||||
} else {
|
||||
log.Printf("tavily search ok: %d bytes", len(searchResult))
|
||||
}
|
||||
messages = append(messages, ollama.ChatMessage{
|
||||
Role: "tool", ToolName: "web_search", Content: searchResult,
|
||||
})
|
||||
}
|
||||
}
|
||||
// Too many tool rounds
|
||||
sendError(w, "too many tool-call rounds", http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// sendError writes a JSON error reply of the form {"error": msg} with the
// given HTTP status code.
func sendError(w http.ResponseWriter, msg string, code int) {
	payload := struct {
		Error string `json:"error"`
	}{Error: msg}

	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// The status is already on the wire; an encode failure cannot be reported.
	_ = json.NewEncoder(w).Encode(payload)
}
|
||||
|
||||
// escapeJSONString returns s escaped for embedding between double quotes in a
// JSON document. It marshals s as a JSON string and strips the surrounding
// quotes.
func escapeJSONString(s string) string {
	// The marshal error is discarded; the input is always a plain string.
	quoted, _ := json.Marshal(s)
	inner := quoted[1 : len(quoted)-1]
	return string(inner)
}
|
||||
17
play-life-llm/internal/handler/health.go
Normal file
17
play-life-llm/internal/handler/health.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Health answers GET requests with 200 and {"status": "ok"} (used by the
// Docker healthcheck). Any other method gets 405.
func Health(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		body := struct {
			Status string `json:"status"`
		}{Status: "ok"}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		// The status is already written; an encode failure cannot be reported.
		_ = json.NewEncoder(w).Encode(body)
	default:
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
	}
}
|
||||
148
play-life-llm/internal/ollama/client.go
Normal file
148
play-life-llm/internal/ollama/client.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// defaultTimeout is the HTTP timeout given to clients built by NewClient.
const defaultTimeout time.Duration = 10 * time.Minute
|
||||
|
||||
// Client talks to the Ollama /api/chat endpoint.
type Client struct {
	// BaseURL is the Ollama address; "/api/chat" is appended to it.
	BaseURL string
	// HTTPClient performs the requests.
	HTTPClient *http.Client
}
|
||||
|
||||
// NewClient creates an Ollama client. baseURL is e.g. "http://localhost:11434".
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
BaseURL: baseURL,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ChatRequest matches Ollama POST /api/chat body.
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Format interface{} `json:"format,omitempty"` // "json" or JSON schema object
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage is one message in the conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "user", "assistant", "system", "tool"
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"` // for role "tool"
|
||||
}
|
||||
|
||||
// Tool defines a function the model may call.
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunc `json:"function"`
|
||||
}
|
||||
|
||||
// ToolFunc is the description of a callable function inside a Tool.
type ToolFunc struct {
	// Name identifies the function.
	Name string `json:"name"`
	// Description tells the model when to use it.
	Description string `json:"description"`
	// Parameters describes the arguments the function takes.
	Parameters interface{} `json:"parameters"`
}
|
||||
|
||||
// ToolCall is a model request to call a tool.
|
||||
type ToolCall struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolCallFn `json:"function"`
|
||||
}
|
||||
|
||||
// ToolCallFn is the function part of a ToolCall: which function to run and
// with what arguments.
type ToolCallFn struct {
	// Name is the function being called.
	Name string `json:"name"`
	// Arguments may come from Ollama as a JSON object or as a JSON string,
	// so it is decoded into an untyped value.
	Arguments interface{} `json:"arguments"`
}
|
||||
|
||||
// QueryFromToolCall returns the "query" argument from a web_search tool call.
|
||||
// Ollama may send arguments as a map or as a JSON string.
|
||||
func QueryFromToolCall(tc ToolCall) string {
|
||||
switch v := tc.Function.Arguments.(type) {
|
||||
case map[string]interface{}:
|
||||
if q, _ := v["query"].(string); q != "" {
|
||||
return q
|
||||
}
|
||||
case string:
|
||||
var m map[string]interface{}
|
||||
if json.Unmarshal([]byte(v), &m) == nil {
|
||||
if q, _ := m["query"].(string); q != "" {
|
||||
return q
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ChatResponse is the Ollama /api/chat response.
|
||||
type ChatResponse struct {
|
||||
Message ChatMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// Chat sends a chat request and returns the response.
|
||||
func (c *Client) Chat(req *ChatRequest) (*ChatResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
url := c.BaseURL + "/api/chat"
|
||||
httpReq, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.HTTPClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama returned %d: %s", resp.StatusCode, string(b))
|
||||
}
|
||||
|
||||
var out ChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// WebSearchTool returns the tool definition for web_search (Tavily).
|
||||
func WebSearchTool() Tool {
|
||||
return Tool{
|
||||
Type: "function",
|
||||
Function: ToolFunc{
|
||||
Name: "web_search",
|
||||
Description: "Search the web for current information. Use when you need up-to-date or factual information from the internet.",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
35
play-life-llm/internal/server/server.go
Normal file
35
play-life-llm/internal/server/server.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"play-life-llm/internal/handler"
|
||||
"play-life-llm/internal/ollama"
|
||||
"play-life-llm/internal/tavily"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// Config carries the settings NewRouter needs to build its clients and handlers.
type Config struct {
	// OllamaHost is passed to ollama.NewClient as the base URL.
	OllamaHost string
	// TavilyAPIKey is passed to tavily.NewClient.
	TavilyAPIKey string
	// DefaultModel becomes the ask handler's default model.
	DefaultModel string
}
|
||||
|
||||
// NewRouter returns an HTTP router with /health and /ask registered.
|
||||
func NewRouter(cfg Config) http.Handler {
|
||||
ollamaClient := ollama.NewClient(cfg.OllamaHost)
|
||||
tavilyClient := tavily.NewClient(cfg.TavilyAPIKey)
|
||||
|
||||
askHandler := &handler.AskHandler{
|
||||
Ollama: ollamaClient,
|
||||
Tavily: tavilyClient,
|
||||
DefaultModel: cfg.DefaultModel,
|
||||
}
|
||||
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/health", handler.Health).Methods(http.MethodGet)
|
||||
r.Handle("/ask", askHandler).Methods(http.MethodPost)
|
||||
return r
|
||||
}
|
||||
104
play-life-llm/internal/tavily/client.go
Normal file
104
play-life-llm/internal/tavily/client.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package tavily
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Tavily endpoint and client settings.
const (
	// baseURL is the Tavily API host.
	baseURL = "https://api.tavily.com"
	// searchPath is appended to baseURL for search requests.
	searchPath = "/search"
	// timeout is the HTTP timeout given to clients built by NewClient.
	timeout = 30 * time.Second
)
|
||||
|
||||
// Client talks to the Tavily Search API.
type Client struct {
	// APIKey is sent as a Bearer token; Search fails when it is empty.
	APIKey string
	// HTTPClient performs the requests.
	HTTPClient *http.Client
}
|
||||
|
||||
// NewClient creates a Tavily client. apiKey is required for search.
|
||||
func NewClient(apiKey string) *Client {
|
||||
return &Client{
|
||||
APIKey: apiKey,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SearchRequest is the JSON body posted to /search.
type SearchRequest struct {
	// Query is the search text.
	Query string `json:"query"`
	// SearchDepth is e.g. "basic" or "advanced"; omitted when empty.
	SearchDepth string `json:"search_depth,omitempty"`
	// MaxResults limits the number of results; omitted when zero.
	MaxResults int `json:"max_results,omitempty"`
}
|
||||
|
||||
// SearchResult is a single entry of a search response.
type SearchResult struct {
	// Title is the result's title.
	Title string `json:"title"`
	// URL is the result's address.
	URL string `json:"url"`
	// Content is the result's text.
	Content string `json:"content"`
}
|
||||
|
||||
// SearchResponse is the Tavily search response.
|
||||
type SearchResponse struct {
|
||||
Query string `json:"query"`
|
||||
Answer string `json:"answer,omitempty"`
|
||||
Results []SearchResult `json:"results"`
|
||||
}
|
||||
|
||||
// Search runs a web search and returns a single text suitable for passing to Ollama as tool result.
|
||||
func (c *Client) Search(query string) (string, error) {
|
||||
if c.APIKey == "" {
|
||||
return "", fmt.Errorf("tavily: API key not set")
|
||||
}
|
||||
body, err := json.Marshal(SearchRequest{
|
||||
Query: query,
|
||||
MaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := baseURL + searchPath
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("tavily returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var out SearchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
// Build a single text for the model: prefer answer if present, else concatenate results.
|
||||
if out.Answer != "" {
|
||||
return out.Answer, nil
|
||||
}
|
||||
var b bytes.Buffer
|
||||
for i, r := range out.Results {
|
||||
if i > 0 {
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
b.WriteString(r.Title)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(r.Content)
|
||||
}
|
||||
return b.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user