client.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package llm
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strconv"
  7. "strings"
  8. "time"
  9. openai "github.com/sashabaranov/go-openai"
  10. "spider/internal/extractor"
  11. )
  12. // Client OpenAI 兼容的 LLM 客户端
  13. type Client struct {
  14. client *openai.Client
  15. model string
  16. timeout time.Duration
  17. }
  18. // New 创建客户端,支持任意 OpenAI 兼容接口
  19. // baseURL 为空时使用 OpenAI 官方接口
  20. func New(baseURL, apiKey, model string, timeout time.Duration) *Client {
  21. cfg := openai.DefaultConfig(apiKey)
  22. if baseURL != "" {
  23. cfg.BaseURL = baseURL
  24. }
  25. return &Client{
  26. client: openai.NewClientWithConfig(cfg),
  27. model: model,
  28. timeout: timeout,
  29. }
  30. }
  31. // chat 内部封装:发送 system + user 消息,返回第一条回复文本
  32. func (c *Client) chat(ctx context.Context, system, user string) (string, error) {
  33. ctx, cancel := context.WithTimeout(ctx, c.timeout)
  34. defer cancel()
  35. resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
  36. Model: c.model,
  37. Messages: []openai.ChatCompletionMessage{
  38. {Role: openai.ChatMessageRoleSystem, Content: system},
  39. {Role: openai.ChatMessageRoleUser, Content: user},
  40. },
  41. })
  42. if err != nil {
  43. return "", fmt.Errorf("llm chat: %w", err)
  44. }
  45. if len(resp.Choices) == 0 {
  46. return "", fmt.Errorf("llm chat: empty response")
  47. }
  48. return strings.TrimSpace(resp.Choices[0].Message.Content), nil
  49. }
  50. // EvalChannelRelevance 评估 TG 频道是否与商户相关
  51. // 返回相关度评分 0-1,<0.5 认为不相关
  52. // 调用失败时返回 0.5 表示不确定
  53. func (c *Client) EvalChannelRelevance(ctx context.Context, name, about string, memberCount int) (float64, error) {
  54. const system = `你是商户识别专家。请判断以下 Telegram 频道是否与商户/卖家/服务提供商相关。
  55. 只关注是否有商品/服务在售。返回 0-1 的数字,1 表示高度相关,0 表示完全不相关。只返回数字,不要解释。`
  56. user := fmt.Sprintf("频道名:%s\n简介:%s\n成员数:%d", name, about, memberCount)
  57. text, err := c.chat(ctx, system, user)
  58. if err != nil {
  59. return 0.5, err
  60. }
  61. score, parseErr := strconv.ParseFloat(text, 64)
  62. if parseErr != nil {
  63. // 尝试从文本中提取第一个数字
  64. fields := strings.Fields(text)
  65. for _, f := range fields {
  66. if s, e := strconv.ParseFloat(f, 64); e == nil {
  67. return clamp01(s), nil
  68. }
  69. }
  70. return 0.5, fmt.Errorf("llm eval: cannot parse score from %q", text)
  71. }
  72. return clamp01(score), nil
  73. }
  74. // ParseMerchant 从消息文本中解析商户信息
  75. // 用于正则提取失败时的 fallback,或提取非标准格式如"加V:xxx"
  76. func (c *Client) ParseMerchant(ctx context.Context, message string) (*extractor.MerchantInfo, error) {
  77. const system = `你是信息提取专家。从以下文本中提取商户联系信息,返回 JSON 格式。
  78. 字段:merchant_name, tg_username(不含@), website, email, phone, industry, description
  79. 如果某字段没有信息则为空字符串。只返回 JSON,不要 markdown 代码块。`
  80. text, err := c.chat(ctx, system, message)
  81. if err != nil {
  82. return defaultMerchantInfo(), err
  83. }
  84. // 去除可能的 markdown 代码块包裹
  85. text = stripMarkdownCode(text)
  86. info := &extractor.MerchantInfo{}
  87. if jsonErr := json.Unmarshal([]byte(text), info); jsonErr != nil {
  88. return defaultMerchantInfo(), fmt.Errorf("llm parse merchant: json unmarshal: %w (raw: %s)", jsonErr, text)
  89. }
  90. return info, nil
  91. }
  92. // ClassifyIndustry 行业分类
  93. // 返回行业标签:机场/发卡/成人/电商/游戏/其他 等
  94. func (c *Client) ClassifyIndustry(ctx context.Context, name, about string) (string, error) {
  95. const system = `你是电商行业分类专家。根据频道信息,从以下类别中选择最匹配的一个:
  96. 机场、发卡、成人、电商、游戏充值、金融、软件工具、其他
  97. 只返回类别名称,不要解释。`
  98. user := fmt.Sprintf("名称:%s,简介:%s", name, about)
  99. text, err := c.chat(ctx, system, user)
  100. if err != nil {
  101. return "其他", err
  102. }
  103. return strings.TrimSpace(text), nil
  104. }
  105. // IsNavSite 判断 URL 是否是导航站/目录站
  106. // 返回 (是否是导航站, 置信度 0-1)
  107. func (c *Client) IsNavSite(ctx context.Context, url string) (bool, float64, error) {
  108. const system = `判断以下 URL 是否是导航站、目录站或聚合站(收录多个商家/服务的网站)。
  109. 返回 JSON: {"is_nav": true/false, "confidence": 0.0-1.0}`
  110. text, err := c.chat(ctx, system, url)
  111. if err != nil {
  112. return false, 0, err
  113. }
  114. text = stripMarkdownCode(text)
  115. var result struct {
  116. IsNav bool `json:"is_nav"`
  117. Confidence float64 `json:"confidence"`
  118. }
  119. if jsonErr := json.Unmarshal([]byte(text), &result); jsonErr != nil {
  120. return false, 0, fmt.Errorf("llm is_nav_site: json unmarshal: %w (raw: %s)", jsonErr, text)
  121. }
  122. return result.IsNav, clamp01(result.Confidence), nil
  123. }
  124. // stripMarkdownCode 去除 LLM 响应中可能包含的 markdown 代码块标记
  125. func stripMarkdownCode(s string) string {
  126. s = strings.TrimSpace(s)
  127. // 去除 ```json ... ``` 或 ``` ... ```
  128. if strings.HasPrefix(s, "```") {
  129. lines := strings.SplitN(s, "\n", 2)
  130. if len(lines) == 2 {
  131. s = lines[1]
  132. }
  133. if idx := strings.LastIndex(s, "```"); idx >= 0 {
  134. s = s[:idx]
  135. }
  136. s = strings.TrimSpace(s)
  137. }
  138. return s
  139. }
  140. // clamp01 将浮点数限制在 [0, 1] 范围内
  141. func clamp01(v float64) float64 {
  142. if v < 0 {
  143. return 0
  144. }
  145. if v > 1 {
  146. return 1
  147. }
  148. return v
  149. }
  150. // defaultMerchantInfo 返回空的 MerchantInfo(JSON 解析失败时的默认值)
  151. func defaultMerchantInfo() *extractor.MerchantInfo {
  152. return &extractor.MerchantInfo{}
  153. }