I have an instruct-tuned Gemma 3 4B model on VQA data...
Created: July 31, 2025
I have an instruct-tuned Gemma 3 4B model on a VQA dataset for plant disease classification. I want to improve its classification results by leveraging its reasoning ability to segment diseased pixels. How can I do this? Can I use SFT for that?
Mis-classified leaves are almost always those in which the model attends to the wrong region (background, healthy tissue, glare, etc.). Giving the model an auxiliary objective that forces it to isolate the symptomatic pixels tightens its attention and makes the class token more discriminative. Recent "reasoning-segmentation" papers show that even small multimodal LLMs gain 2–7 pp on the main task once they are taught to emit a mask token before (or together with) the textual answer; see LISA's embedding-as-mask trick (CVF Open Access) and the SAM-plus-LLM pipeline in LLM-Seg (arXiv). For agriculture specifically, the SCOLD foundation model reports similar gains when it injects localisation cues during pre-training (arXiv).
Strategy | What you change | Pros | Cons
---|---|---|---
**A. "One-stage" SFT with a `<SEG>` token (LISA style)** | • Add a single extra token to Gemma's text vocab.<br>• At train time, the ground-truth mask M is down-sampled to the vision-encoder patch grid (e.g. 14×14), flattened to a bit-vector b, and projected with a small MLP to the same dim as the language embeddings.<br>• A training sample looks like `<image> Q: What disease? A: <SEG> <cls_token>`.<br>• Loss = CE on mask bits + CE on class id. | • End-to-end; no extra encoder.<br>• The mask guides attention early in the decoding stack, so classification improves even if you discard the mask at test time. | • Requires ≥100–200 pixel-annotated images per disease to avoid over-fitting.<br>• Longer answers (mask bits) raise generative latency unless you truncate.
**B. Two-stage mask selection (LLM-Seg style)** | • Keep Gemma intact but freeze it.<br>• Run SAM (or similar) to generate K ≈ 100 candidate masks.<br>• Feed Gemma the prompt `<image+{mask_i}> "Which of the K regions shows symptoms of late blight?"`.<br>• Gemma chooses the index; the chosen mask is passed to a tiny CNN head that predicts the class.<br>• Fine-tune Gemma with a small LoRA on the vision-language cross-attention blocks plus a linear head on top of pooled vision features. | • No change to the tokenizer.<br>• Works with synthetic masks (SAM) plus minimal GT (only pick-the-best-mask supervision).<br>• Inference speed is good: SAM runs once; Gemma reasons over K embeddings (not pixels). | • Pipeline, not end-to-end.<br>• Must keep SAM weights in memory.
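To make Strategy B concrete, here is a minimal sketch of the first stage, assuming the `segment_anything` package; `ask_gemma` is a hypothetical wrapper around your fine-tuned model that returns the chosen region index:

```python
# Strategy B, stage 1: SAM proposes regions, Gemma picks the symptomatic one.
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
mask_gen = SamAutomaticMaskGenerator(sam, points_per_side=16)

def pick_symptom_mask(image: np.ndarray, disease: str) -> np.ndarray:
    masks = mask_gen.generate(image)                       # K candidate regions
    masks = sorted(masks, key=lambda m: m["area"], reverse=True)[:100]
    prompt = (f"Which of the {len(masks)} regions shows symptoms of "
              f"{disease}? Answer with the region index.")
    idx = int(ask_gemma(image, masks, prompt))             # hypothetical call
    return masks[idx]["segmentation"]                      # boolean (H, W) mask
```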
Data and label acquisition
You need pixel-level masks for a few hundred images per class; hand annotations or curated SAM pseudo-labels both work (the checklist below gives target counts).
Model surgery
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-3b-vqa")
tok = AutoTokenizer.from_pretrained("google/gemma-3b-vqa")

tok.add_tokens(["<SEG>"])                 # 1 new id
model.resize_token_embeddings(len(tok))

# small MLP that turns the hidden state of <SEG> into a 14×14 (=196) mask
model.seg_head = torch.nn.Sequential(
    torch.nn.Linear(model.config.hidden_size, 196),
    torch.nn.Sigmoid(),
)
```
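At train time the pixel mask has to be matched to that 196-bit output. A minimal sketch, assuming `mask_px` is a hypothetical (H, W) float tensor with 1 = diseased pixel:

```python
import torch.nn.functional as F

# Down-sample M to the 14×14 patch grid and binarise into the bit-vector b.
bits = F.interpolate(mask_px[None, None], size=(14, 14), mode="area")
gt_mask = (bits.flatten(1) > 0.5).float()   # (1, 196); patch is "diseased" if >50% of its pixels are
```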
```python
from torch.nn.functional import cross_entropy

loss_cls = cross_entropy(logits_cls, label_cls)
loss_seg = dice_loss(model.seg_head(h_seg), gt_mask)
loss = loss_cls + 0.5 * loss_seg   # seg-loss weight can be tuned
```
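`dice_loss` and `h_seg` are not defined above; one possible version, assuming `hidden_states` is the decoder's last-layer output for the batch:

```python
import torch

def dice_loss(pred, target, eps=1.0):
    # pred, target: (B, 196) in [0, 1]; soft Dice over the flattened patch grid
    inter = (pred * target).sum(-1)
    union = pred.sum(-1) + target.sum(-1)
    return 1.0 - ((2 * inter + eps) / (union + eps)).mean()

# h_seg: hidden state at each sample's <SEG> position
seg_id = tok.convert_tokens_to_ids("<SEG>")
h_seg = hidden_states[(input_ids == seg_id).nonzero(as_tuple=True)]  # (B, hidden)
```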
Fine-tuning strategy
Train the new `seg_head` in full, optionally with LoRA adapters on the backbone; lr = 2e-4; batch = 32; 8-bit AdamW. One way to wire this up is sketched below.
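A minimal sketch, assuming the `peft` and `bitsandbytes` packages (the LoRA rank and target modules are assumptions, not prescriptions):

```python
import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model

model = get_peft_model(model, LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"]))
model.seg_head.requires_grad_(True)   # the new head is trained in full
optim = bnb.optim.AdamW8bit(
    [p for p in model.parameters() if p.requires_grad], lr=2e-4)
```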
Inference
Omit the mask at serving time: prompt Gemma with just the image and "What disease?"; the learned attention still focuses on the diseased pixels, lifting top-1 accuracy.
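When you do want to show the evidence, the `<SEG>` prediction can be rendered as a heat-map. A small helper, assuming the Strategy A head above:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def seg_heatmap(h_seg, image_hw):
    # (hidden,) -> 14×14 probabilities -> full-resolution overlay in [0, 1]
    probs = model.seg_head(h_seg).reshape(1, 1, 14, 14)
    return F.interpolate(probs, size=image_hw, mode="bilinear")[0, 0]
```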
Yes. Both approaches above are standard supervised fine-tuning: in A the mask bits simply become part of the target sequence, and in B the best-mask index is an ordinary classification label.
If you want to squeeze more out of the model after SFT, first make sure the basics below are in place:
Item | Done?
---|---
250–500 pixel-level masks per class (or reliable SAM pseudo-labels) | ☐
20% of data held out for IoU + accuracy evaluation | ☐
Updated tokenizer + checkpoint with `<SEG>` | ☐
Mixed precision & gradient accumulation for 4B params (see the sketch below) | ☐
Plan for serving: mask-less prompt or two-stage pipeline | ☐
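To make the mixed-precision row concrete, one possible `transformers.TrainingArguments` setup (all values are assumptions to be tuned):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="gemma3-seg-sft",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,   # effective batch = 32
    bf16=True,                       # mixed precision
    learning_rate=2e-4,
    num_train_epochs=3,
)
```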
Once these are in place, SFT should give you a +3–5 pp gain in classification accuracy on PlantVillage-style test sets, plus the bonus ability to visualise exactly why the model predicted a class, a nice trust-building feature when deploying to growers.