package llm import ( "context" "encoding/json" "fmt" "strconv" "strings" "time" openai "github.com/sashabaranov/go-openai" "spider/internal/extractor" ) // Client OpenAI 兼容的 LLM 客户端 type Client struct { client *openai.Client model string timeout time.Duration } // New 创建客户端,支持任意 OpenAI 兼容接口 // baseURL 为空时使用 OpenAI 官方接口 func New(baseURL, apiKey, model string, timeout time.Duration) *Client { cfg := openai.DefaultConfig(apiKey) if baseURL != "" { cfg.BaseURL = baseURL } return &Client{ client: openai.NewClientWithConfig(cfg), model: model, timeout: timeout, } } // chat 内部封装:发送 system + user 消息,返回第一条回复文本 func (c *Client) chat(ctx context.Context, system, user string) (string, error) { ctx, cancel := context.WithTimeout(ctx, c.timeout) defer cancel() resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: c.model, Messages: []openai.ChatCompletionMessage{ {Role: openai.ChatMessageRoleSystem, Content: system}, {Role: openai.ChatMessageRoleUser, Content: user}, }, }) if err != nil { return "", fmt.Errorf("llm chat: %w", err) } if len(resp.Choices) == 0 { return "", fmt.Errorf("llm chat: empty response") } return strings.TrimSpace(resp.Choices[0].Message.Content), nil } // EvalChannelRelevance 评估 TG 频道是否与商户相关 // 返回相关度评分 0-1,<0.5 认为不相关 // 调用失败时返回 0.5 表示不确定 func (c *Client) EvalChannelRelevance(ctx context.Context, name, about string, memberCount int) (float64, error) { const system = `你是商户识别专家。请判断以下 Telegram 频道是否与商户/卖家/服务提供商相关。 只关注是否有商品/服务在售。返回 0-1 的数字,1 表示高度相关,0 表示完全不相关。只返回数字,不要解释。` user := fmt.Sprintf("频道名:%s\n简介:%s\n成员数:%d", name, about, memberCount) text, err := c.chat(ctx, system, user) if err != nil { return 0.5, err } score, parseErr := strconv.ParseFloat(text, 64) if parseErr != nil { // 尝试从文本中提取第一个数字 fields := strings.Fields(text) for _, f := range fields { if s, e := strconv.ParseFloat(f, 64); e == nil { return clamp01(s), nil } } return 0.5, fmt.Errorf("llm eval: cannot parse score from %q", text) } return clamp01(score), nil } // ParseMerchant 从消息文本中解析商户信息 // 用于正则提取失败时的 fallback,或提取非标准格式如"加V:xxx" func (c *Client) ParseMerchant(ctx context.Context, message string) (*extractor.MerchantInfo, error) { const system = `你是信息提取专家。从以下文本中提取商户联系信息,返回 JSON 格式。 字段:merchant_name, tg_username(不含@), website, email, phone, industry, description 如果某字段没有信息则为空字符串。只返回 JSON,不要 markdown 代码块。` text, err := c.chat(ctx, system, message) if err != nil { return defaultMerchantInfo(), err } // 去除可能的 markdown 代码块包裹 text = stripMarkdownCode(text) info := &extractor.MerchantInfo{} if jsonErr := json.Unmarshal([]byte(text), info); jsonErr != nil { return defaultMerchantInfo(), fmt.Errorf("llm parse merchant: json unmarshal: %w (raw: %s)", jsonErr, text) } return info, nil } // ClassifyIndustry 行业分类 // 返回行业标签:机场/发卡/成人/电商/游戏/其他 等 func (c *Client) ClassifyIndustry(ctx context.Context, name, about string) (string, error) { const system = `你是电商行业分类专家。根据频道信息,从以下类别中选择最匹配的一个: 机场、发卡、成人、电商、游戏充值、金融、软件工具、其他 只返回类别名称,不要解释。` user := fmt.Sprintf("名称:%s,简介:%s", name, about) text, err := c.chat(ctx, system, user) if err != nil { return "其他", err } return strings.TrimSpace(text), nil } // IsNavSite 判断 URL 是否是导航站/目录站 // 返回 (是否是导航站, 置信度 0-1) func (c *Client) IsNavSite(ctx context.Context, url string) (bool, float64, error) { const system = `判断以下 URL 是否是导航站、目录站或聚合站(收录多个商家/服务的网站)。 返回 JSON: {"is_nav": true/false, "confidence": 0.0-1.0}` text, err := c.chat(ctx, system, url) if err != nil { return false, 0, err } text = stripMarkdownCode(text) var result struct { IsNav bool `json:"is_nav"` Confidence float64 `json:"confidence"` } if jsonErr := json.Unmarshal([]byte(text), &result); jsonErr != nil { return false, 0, fmt.Errorf("llm is_nav_site: json unmarshal: %w (raw: %s)", jsonErr, text) } return result.IsNav, clamp01(result.Confidence), nil } // stripMarkdownCode 去除 LLM 响应中可能包含的 markdown 代码块标记 func stripMarkdownCode(s string) string { s = strings.TrimSpace(s) // 去除 ```json ... ``` 或 ``` ... ``` if strings.HasPrefix(s, "```") { lines := strings.SplitN(s, "\n", 2) if len(lines) == 2 { s = lines[1] } if idx := strings.LastIndex(s, "```"); idx >= 0 { s = s[:idx] } s = strings.TrimSpace(s) } return s } // clamp01 将浮点数限制在 [0, 1] 范围内 func clamp01(v float64) float64 { if v < 0 { return 0 } if v > 1 { return 1 } return v } // defaultMerchantInfo 返回空的 MerchantInfo(JSON 解析失败时的默认值) func defaultMerchantInfo() *extractor.MerchantInfo { return &extractor.MerchantInfo{} }