client.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package llm
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. "time"
  8. openai "github.com/sashabaranov/go-openai"
  9. "spider/internal/extractor"
  10. )
  11. // Client is an OpenAI-compatible LLM client.
  12. // Used only for TG message merchant extraction (fallback when regex fails).
  13. type Client struct {
  14. client *openai.Client
  15. model string
  16. timeout time.Duration
  17. }
  18. // New creates a client. baseURL empty = OpenAI official endpoint.
  19. func New(baseURL, apiKey, model string, timeout time.Duration) *Client {
  20. cfg := openai.DefaultConfig(apiKey)
  21. if baseURL != "" {
  22. cfg.BaseURL = baseURL
  23. }
  24. return &Client{
  25. client: openai.NewClientWithConfig(cfg),
  26. model: model,
  27. timeout: timeout,
  28. }
  29. }
  30. func (c *Client) chat(ctx context.Context, system, user string) (string, error) {
  31. ctx, cancel := context.WithTimeout(ctx, c.timeout)
  32. defer cancel()
  33. resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
  34. Model: c.model,
  35. Messages: []openai.ChatCompletionMessage{
  36. {Role: openai.ChatMessageRoleSystem, Content: system},
  37. {Role: openai.ChatMessageRoleUser, Content: user},
  38. },
  39. })
  40. if err != nil {
  41. return "", fmt.Errorf("llm chat: %w", err)
  42. }
  43. if len(resp.Choices) == 0 {
  44. return "", fmt.Errorf("llm chat: empty response")
  45. }
  46. return strings.TrimSpace(resp.Choices[0].Message.Content), nil
  47. }
  48. // ParseMerchant extracts merchant info from text.
  49. // Used as fallback when regex extraction fails on non-standard formats like "加V:xxx".
  50. func (c *Client) ParseMerchant(ctx context.Context, message string) (*extractor.MerchantInfo, error) {
  51. const system = `你是信息提取专家。从以下文本中提取商户联系信息,返回 JSON 格式。
  52. 字段:merchant_name, tg_username(不含@), website, email, phone, industry, description
  53. 如果某字段没有信息则为空字符串。只返回 JSON,不要 markdown 代码块。`
  54. text, err := c.chat(ctx, system, message)
  55. if err != nil {
  56. return &extractor.MerchantInfo{}, err
  57. }
  58. text = stripMarkdownCode(text)
  59. info := &extractor.MerchantInfo{}
  60. if jsonErr := json.Unmarshal([]byte(text), info); jsonErr != nil {
  61. return &extractor.MerchantInfo{}, fmt.Errorf("llm parse merchant: json unmarshal: %w (raw: %s)", jsonErr, text)
  62. }
  63. return info, nil
  64. }
  65. func stripMarkdownCode(s string) string {
  66. s = strings.TrimSpace(s)
  67. if strings.HasPrefix(s, "```") {
  68. lines := strings.SplitN(s, "\n", 2)
  69. if len(lines) == 2 {
  70. s = lines[1]
  71. }
  72. if idx := strings.LastIndex(s, "```"); idx >= 0 {
  73. s = s[:idx]
  74. }
  75. s = strings.TrimSpace(s)
  76. }
  77. return s
  78. }