| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- package llm
- import (
- "context"
- "encoding/json"
- "fmt"
- "strconv"
- "strings"
- "time"
- openai "github.com/sashabaranov/go-openai"
- "spider/internal/extractor"
- )
- // Client OpenAI 兼容的 LLM 客户端
- type Client struct {
- client *openai.Client
- model string
- timeout time.Duration
- }
- // New 创建客户端,支持任意 OpenAI 兼容接口
- // baseURL 为空时使用 OpenAI 官方接口
- func New(baseURL, apiKey, model string, timeout time.Duration) *Client {
- cfg := openai.DefaultConfig(apiKey)
- if baseURL != "" {
- cfg.BaseURL = baseURL
- }
- return &Client{
- client: openai.NewClientWithConfig(cfg),
- model: model,
- timeout: timeout,
- }
- }
- // chat 内部封装:发送 system + user 消息,返回第一条回复文本
- func (c *Client) chat(ctx context.Context, system, user string) (string, error) {
- ctx, cancel := context.WithTimeout(ctx, c.timeout)
- defer cancel()
- resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
- Model: c.model,
- Messages: []openai.ChatCompletionMessage{
- {Role: openai.ChatMessageRoleSystem, Content: system},
- {Role: openai.ChatMessageRoleUser, Content: user},
- },
- })
- if err != nil {
- return "", fmt.Errorf("llm chat: %w", err)
- }
- if len(resp.Choices) == 0 {
- return "", fmt.Errorf("llm chat: empty response")
- }
- return strings.TrimSpace(resp.Choices[0].Message.Content), nil
- }
- // EvalChannelRelevance 评估 TG 频道是否与商户相关
- // 返回相关度评分 0-1,<0.5 认为不相关
- // 调用失败时返回 0.5 表示不确定
- func (c *Client) EvalChannelRelevance(ctx context.Context, name, about string, memberCount int) (float64, error) {
- const system = `你是商户识别专家。请判断以下 Telegram 频道是否与商户/卖家/服务提供商相关。
- 只关注是否有商品/服务在售。返回 0-1 的数字,1 表示高度相关,0 表示完全不相关。只返回数字,不要解释。`
- user := fmt.Sprintf("频道名:%s\n简介:%s\n成员数:%d", name, about, memberCount)
- text, err := c.chat(ctx, system, user)
- if err != nil {
- return 0.5, err
- }
- score, parseErr := strconv.ParseFloat(text, 64)
- if parseErr != nil {
- // 尝试从文本中提取第一个数字
- fields := strings.Fields(text)
- for _, f := range fields {
- if s, e := strconv.ParseFloat(f, 64); e == nil {
- return clamp01(s), nil
- }
- }
- return 0.5, fmt.Errorf("llm eval: cannot parse score from %q", text)
- }
- return clamp01(score), nil
- }
- // ParseMerchant 从消息文本中解析商户信息
- // 用于正则提取失败时的 fallback,或提取非标准格式如"加V:xxx"
- func (c *Client) ParseMerchant(ctx context.Context, message string) (*extractor.MerchantInfo, error) {
- const system = `你是信息提取专家。从以下文本中提取商户联系信息,返回 JSON 格式。
- 字段:merchant_name, tg_username(不含@), website, email, phone, industry, description
- 如果某字段没有信息则为空字符串。只返回 JSON,不要 markdown 代码块。`
- text, err := c.chat(ctx, system, message)
- if err != nil {
- return defaultMerchantInfo(), err
- }
- // 去除可能的 markdown 代码块包裹
- text = stripMarkdownCode(text)
- info := &extractor.MerchantInfo{}
- if jsonErr := json.Unmarshal([]byte(text), info); jsonErr != nil {
- return defaultMerchantInfo(), fmt.Errorf("llm parse merchant: json unmarshal: %w (raw: %s)", jsonErr, text)
- }
- return info, nil
- }
- // ClassifyIndustry 行业分类
- // 返回行业标签:机场/发卡/成人/电商/游戏/其他 等
- func (c *Client) ClassifyIndustry(ctx context.Context, name, about string) (string, error) {
- const system = `你是电商行业分类专家。根据频道信息,从以下类别中选择最匹配的一个:
- 机场、发卡、成人、电商、游戏充值、金融、软件工具、其他
- 只返回类别名称,不要解释。`
- user := fmt.Sprintf("名称:%s,简介:%s", name, about)
- text, err := c.chat(ctx, system, user)
- if err != nil {
- return "其他", err
- }
- return strings.TrimSpace(text), nil
- }
- // IsNavSite 判断 URL 是否是导航站/目录站
- // 返回 (是否是导航站, 置信度 0-1)
- func (c *Client) IsNavSite(ctx context.Context, url string) (bool, float64, error) {
- const system = `判断以下 URL 是否是导航站、目录站或聚合站(收录多个商家/服务的网站)。
- 返回 JSON: {"is_nav": true/false, "confidence": 0.0-1.0}`
- text, err := c.chat(ctx, system, url)
- if err != nil {
- return false, 0, err
- }
- text = stripMarkdownCode(text)
- var result struct {
- IsNav bool `json:"is_nav"`
- Confidence float64 `json:"confidence"`
- }
- if jsonErr := json.Unmarshal([]byte(text), &result); jsonErr != nil {
- return false, 0, fmt.Errorf("llm is_nav_site: json unmarshal: %w (raw: %s)", jsonErr, text)
- }
- return result.IsNav, clamp01(result.Confidence), nil
- }
- // stripMarkdownCode 去除 LLM 响应中可能包含的 markdown 代码块标记
- func stripMarkdownCode(s string) string {
- s = strings.TrimSpace(s)
- // 去除 ```json ... ``` 或 ``` ... ```
- if strings.HasPrefix(s, "```") {
- lines := strings.SplitN(s, "\n", 2)
- if len(lines) == 2 {
- s = lines[1]
- }
- if idx := strings.LastIndex(s, "```"); idx >= 0 {
- s = s[:idx]
- }
- s = strings.TrimSpace(s)
- }
- return s
- }
- // clamp01 将浮点数限制在 [0, 1] 范围内
- func clamp01(v float64) float64 {
- if v < 0 {
- return 0
- }
- if v > 1 {
- return 1
- }
- return v
- }
- // defaultMerchantInfo 返回空的 MerchantInfo(JSON 解析失败时的默认值)
- func defaultMerchantInfo() *extractor.MerchantInfo {
- return &extractor.MerchantInfo{}
- }
|