| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- package llm
- import (
- "context"
- "encoding/json"
- "fmt"
- "strings"
- "time"
- openai "github.com/sashabaranov/go-openai"
- "spider/internal/extractor"
- )
- // Client is an OpenAI-compatible LLM client.
- // Used only for TG message merchant extraction (fallback when regex fails).
- type Client struct {
- client *openai.Client
- model string
- timeout time.Duration
- }
- // New creates a client. baseURL empty = OpenAI official endpoint.
- func New(baseURL, apiKey, model string, timeout time.Duration) *Client {
- cfg := openai.DefaultConfig(apiKey)
- if baseURL != "" {
- cfg.BaseURL = baseURL
- }
- return &Client{
- client: openai.NewClientWithConfig(cfg),
- model: model,
- timeout: timeout,
- }
- }
- func (c *Client) chat(ctx context.Context, system, user string) (string, error) {
- ctx, cancel := context.WithTimeout(ctx, c.timeout)
- defer cancel()
- resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
- Model: c.model,
- Messages: []openai.ChatCompletionMessage{
- {Role: openai.ChatMessageRoleSystem, Content: system},
- {Role: openai.ChatMessageRoleUser, Content: user},
- },
- })
- if err != nil {
- return "", fmt.Errorf("llm chat: %w", err)
- }
- if len(resp.Choices) == 0 {
- return "", fmt.Errorf("llm chat: empty response")
- }
- return strings.TrimSpace(resp.Choices[0].Message.Content), nil
- }
- // ParseMerchant extracts merchant info from text.
- // Used as fallback when regex extraction fails on non-standard formats like "加V:xxx".
- func (c *Client) ParseMerchant(ctx context.Context, message string) (*extractor.MerchantInfo, error) {
- const system = `你是信息提取专家。从以下文本中提取商户联系信息,返回 JSON 格式。
- 字段:merchant_name, tg_username(不含@), website, email, phone, industry, description
- 如果某字段没有信息则为空字符串。只返回 JSON,不要 markdown 代码块。`
- text, err := c.chat(ctx, system, message)
- if err != nil {
- return &extractor.MerchantInfo{}, err
- }
- text = stripMarkdownCode(text)
- info := &extractor.MerchantInfo{}
- if jsonErr := json.Unmarshal([]byte(text), info); jsonErr != nil {
- return &extractor.MerchantInfo{}, fmt.Errorf("llm parse merchant: json unmarshal: %w (raw: %s)", jsonErr, text)
- }
- return info, nil
- }
// stripMarkdownCode removes a surrounding markdown code fence from s, if any,
// and returns the inner text with leading and trailing whitespace trimmed.
//
// Input that does not start with "```" (after trimming) is returned trimmed
// but otherwise unchanged. For fenced input:
//   - multi-line: the whole first line (the fence plus any language tag) is
//     dropped;
//   - single-line, e.g. "```{...}```": only the opening "```" marker is
//     dropped, since a language tag cannot be told apart from content;
//   - everything from the last remaining "```" onward is discarded.
func stripMarkdownCode(s string) string {
	const fence = "```"

	s = strings.TrimSpace(s)
	if !strings.HasPrefix(s, fence) {
		return s
	}

	if nl := strings.IndexByte(s, '\n'); nl >= 0 {
		s = s[nl+1:]
	} else {
		// No newline: without this the opening fence would survive and the
		// closing-fence search below would be the only thing removed.
		s = s[len(fence):]
	}

	if end := strings.LastIndex(s, fence); end >= 0 {
		s = s[:end]
	}
	return strings.TrimSpace(s)
}
|