package ai import ( "bytes" "crypto/rand" "encoding/base64" "encoding/json" "fmt" "io" "log" "net/http" "os" "strings" "time" "pdf-form-api/annotator" "pdf-form-api/models" "pdf-form-api/renderer" ) type Client struct { BaseURL string APIKey string Model string timeout time.Duration http *http.Client } func NewClient(baseURL, apiKey, model string) *Client { return &Client{ BaseURL: strings.TrimRight(baseURL, "/"), APIKey: apiKey, Model: model, timeout: 120 * time.Second, http: &http.Client{ Timeout: 120 * time.Second, }, } } type FieldResult struct { Name string Question string } type VisionResult struct { Description string Fields []VisionFieldEntry LabelMap map[string]int64 // label -> field ID } type VisionFieldEntry struct { Label string `json:"label"` Question string `json:"question"` ValueGroup string `json:"value_group"` WizardPage int `json:"wizard_page"` } func (c *Client) GenerateQuestions(pdfPath string, fields []models.FormField) ([]FieldResult, error) { pages, err := renderer.GetPageInfo(pdfPath) if err != nil { return nil, fmt.Errorf("getting page info: %w", err) } grouped := annotator.BuildFieldContexts(fields, pages) var results []FieldResult for pageNum := 1; pageNum <= len(pages); pageNum++ { ctxs := grouped[pageNum] if len(ctxs) == 0 { continue } prompt := annotator.PagePrompt(pageNum, ctxs) pageResults, err := c.askAPI(prompt) if err != nil { return nil, fmt.Errorf("generating questions for page %d: %w", pageNum, err) } results = append(results, pageResults...) } return results, nil } func (c *Client) GenerateQuestionsWithVision(pdfPath string, fields []models.FormField) (*VisionResult, error) { return c.GenerateQuestionsWithVisionAndImages(pdfPath, fields, "", 0) } func (c *Client) GenerateQuestionsWithVisionAndImages(pdfPath string, fields []models.FormField, imageStoreDir string, pdfID int64) (*VisionResult, error) { pages, err := renderer.GetPageInfo(pdfPath) if err != nil { return nil, fmt.Errorf("getting page info: %w", err) } tmpDir, err := os.MkdirTemp("", "pdf-vision-*") if err != nil { return nil, fmt.Errorf("creating temp dir: %w", err) } defer os.RemoveAll(tmpDir) imagePaths, err := renderer.RenderPages(pdfPath, tmpDir, 150) if err != nil { return nil, fmt.Errorf("rendering pages: %w", err) } // Generate non-sequential random labels for each field labelMap := generateLabels(fields) fieldsByPage := groupFieldsByPage(fields) var annotatedPaths []string for i, imgPath := range imagePaths { pageNum := i + 1 pageInfo := pages[i] pageFields := fieldsByPage[pageNum] annotatedPath := tmpDir + "/annotated_page_" + fmt.Sprintf("%d", pageNum) + ".png" if len(pageFields) > 0 { imgW, imgH := renderer.ComputeImageDimensions(pageInfo.Width, pageInfo.Height, 150, 1024) if err := renderer.AnnotateImage(imgPath, annotatedPath, pageFields, labelMap, pageInfo.Width, pageInfo.Height, imgW, imgH); err != nil { return nil, fmt.Errorf("annotating page %d: %w", pageNum, err) } } else { // No fields on this page - use clean image as-is if err := copyFile(imgPath, annotatedPath); err != nil { return nil, fmt.Errorf("copying page %d image: %w", pageNum, err) } } if imageStoreDir != "" { baseDir := fmt.Sprintf("%s/%d", imageStoreDir, pdfID) pageDir := fmt.Sprintf("%s/pages/%d", baseDir, pageNum) if err := os.MkdirAll(pageDir, 0o755); err != nil { log.Printf("[ai] failed to create image dir %s: %v", pageDir, err) } else { dstClean := fmt.Sprintf("%s/page_%d.png", pageDir, pageNum) dstAnnotated := fmt.Sprintf("%s/page_%d_annotated.png", pageDir, pageNum) if err := copyFile(imgPath, dstClean); err != nil { log.Printf("[ai] failed to save clean page %d: %v", pageNum, err) } if err := copyFile(annotatedPath, dstAnnotated); err != nil { log.Printf("[ai] failed to save annotated page %d: %v", pageNum, err) } } } annotatedPaths = append(annotatedPaths, annotatedPath) } visionFields := buildVisionFieldMeta(fields, labelMap) prompt := annotator.BuildVisionPrompt(visionFields, len(pages)) result, err := c.askVisionAPI(prompt, annotatedPaths...) if err != nil { return nil, err } result.LabelMap = buildReverseLabelMap(labelMap) return result, nil } func copyFile(src, dst string) error { data, err := os.ReadFile(src) if err != nil { return err } return os.WriteFile(dst, data, 0o644) } func groupFieldsByPage(fields []models.FormField) map[int][]models.FormField { m := make(map[int][]models.FormField) for _, f := range fields { page := f.Page if page == 0 { page = 1 } m[page] = append(m[page], f) } return m } func generateLabels(fields []models.FormField) map[int64]string { labels := make(map[int64]string, len(fields)) for _, f := range fields { labels[f.ID] = randHex4() } return labels } func randHex4() string { b := make([]byte, 4) _, _ = rand.Read(b) return strings.ToUpper(fmt.Sprintf("%02x%02x%02x%02x", b[0], b[1], b[2], b[3])) } func buildReverseLabelMap(labelMap map[int64]string) map[string]int64 { m := make(map[string]int64, len(labelMap)) for id, label := range labelMap { m[label] = id } return m } func buildVisionFieldMeta(fields []models.FormField, labelMap map[int64]string) []annotator.VisionFieldMeta { var meta []annotator.VisionFieldMeta for _, f := range fields { meta = append(meta, annotator.VisionFieldMeta{ ID: f.ID, Name: f.Name, FieldType: string(f.Type), Choices: f.Choices, DefaultValue: f.DefaultVal, Label: labelMap[f.ID], Page: f.Page, }) } return meta } func (c *Client) askVisionAPI(prompt string, imagePaths ...string) (*VisionResult, error) { contentParts := []contentPart{ {Type: "text", Text: prompt}, } for _, imagePath := range imagePaths { imageData, err := os.ReadFile(imagePath) if err != nil { return nil, fmt.Errorf("reading image %s: %w", imagePath, err) } base64Img := base64.StdEncoding.EncodeToString(imageData) dataURI := "data:image/png;base64," + base64Img contentParts = append(contentParts, contentPart{Type: "image_url", ImageURL: &imageURL{URL: dataURI}}) } reqBody := chatCompletionRequest{ Model: c.Model, Messages: []chatMessage{ {Role: "system", Content: "You are a helpful form wizard assistant. You analyze PDF form images and generate questions for form fields."}, {Role: "user", ContentParts: contentParts}, }, MaxTokens: 4096, Temperature: 0.3, } jsonBody, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("marshaling request: %w", err) } httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("creating request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+c.APIKey) resp, err := c.http.Do(httpReq) if err != nil { return nil, fmt.Errorf("http request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("api error %d: %s", resp.StatusCode, truncate(string(body), 500)) } var completion chatCompletionResponse if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { return nil, fmt.Errorf("decoding response: %w", err) } if len(completion.Choices) == 0 { return nil, fmt.Errorf("no choices in response") } return parseVisionResponse(completion.Choices[0].Message.Content) } func (c *Client) askAPI(prompt string) ([]FieldResult, error) { reqBody := chatCompletionRequest{ Model: c.Model, Messages: []chatMessage{ {Role: "system", Content: "You are a helpful form wizard assistant. Generate clear, natural-language questions for form fields."}, {Role: "user", Content: prompt}, }, MaxTokens: 8192, Temperature: 0.3, } jsonBody, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("marshaling request: %w", err) } httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("creating request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+c.APIKey) resp, err := c.http.Do(httpReq) if err != nil { return nil, fmt.Errorf("http request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("api error %d: %s", resp.StatusCode, truncate(string(body), 500)) } var completion chatCompletionResponse if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { return nil, fmt.Errorf("decoding response: %w", err) } if len(completion.Choices) == 0 { return nil, fmt.Errorf("no choices in response") } return parseResponse(completion.Choices[0].Message.Content) } func parseResponse(content string) ([]FieldResult, error) { content = strings.TrimSpace(content) startIdx := strings.Index(content, "[") endIdx := strings.LastIndex(content, "]") if startIdx < 0 || endIdx < 0 || endIdx <= startIdx { return nil, fmt.Errorf("could not find JSON array in response: %s", truncate(content, 200)) } content = content[startIdx : endIdx+1] var parsed []struct { Name string `json:"name"` Question string `json:"question"` } if err := json.Unmarshal([]byte(content), &parsed); err != nil { return nil, fmt.Errorf("parsing AI response JSON: %w. Raw: %s", err, truncate(content, 300)) } var results []FieldResult for _, p := range parsed { if p.Name == "" { continue } if p.Question == "" { p.Question = p.Name } results = append(results, FieldResult{ Name: strings.TrimSpace(p.Name), Question: strings.TrimSpace(p.Question), }) } return results, nil } func parseVisionResponse(content string) (*VisionResult, error) { content = strings.TrimSpace(content) startIdx := strings.Index(content, "{") endIdx := strings.LastIndex(content, "}") if startIdx < 0 || endIdx < 0 || endIdx <= startIdx { return nil, fmt.Errorf("could not find JSON object in response: %s", truncate(content, 200)) } content = content[startIdx : endIdx+1] var parsed struct { Description string `json:"description"` Fields []VisionFieldEntry `json:"fields"` } if err := json.Unmarshal([]byte(content), &parsed); err != nil { return nil, fmt.Errorf("parsing AI response JSON: %w. Raw: %s", err, truncate(content, 300)) } for i := range parsed.Fields { parsed.Fields[i].Label = strings.TrimSpace(parsed.Fields[i].Label) parsed.Fields[i].Question = strings.TrimSpace(parsed.Fields[i].Question) parsed.Fields[i].ValueGroup = strings.TrimSpace(parsed.Fields[i].ValueGroup) if parsed.Fields[i].Question == "" { parsed.Fields[i].Question = parsed.Fields[i].Label } } return &VisionResult{ Description: strings.TrimSpace(parsed.Description), Fields: parsed.Fields, }, nil } func truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } return s[:maxLen] + "..." } type chatCompletionRequest struct { Model string `json:"model"` Messages []chatMessage `json:"messages"` MaxTokens int `json:"max_completion_tokens"` Temperature float64 `json:"temperature"` } type chatMessage struct { Role string `json:"role"` Content string `json:"content,omitempty"` ContentParts []contentPart `json:"-"` } func (m chatMessage) MarshalJSON() ([]byte, error) { type alias chatMessage if len(m.ContentParts) > 0 { return json.Marshal(&struct { Role string `json:"role"` Content []contentPart `json:"content"` }{ Role: m.Role, Content: m.ContentParts, }) } return json.Marshal(&alias{ Role: m.Role, Content: m.Content, }) } type contentPart struct { Type string `json:"type"` Text string `json:"text,omitempty"` ImageURL *imageURL `json:"image_url,omitempty"` } type imageURL struct { URL string `json:"url"` } type chatCompletionResponse struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` } `json:"choices"` }