Fine-tuning LLM với Unsloth và QLoRA: Huấn luyện Llama 3 trên GPU dân dụng

AI tutorial - IT technology blog
AI tutorial - IT technology blog

Rào cản VRAM: Tại sao huấn luyện cục bộ thường thất bại

Lần đầu tiên tôi thử fine-tune Llama 2 trên máy trạm cá nhân là một thảm họa. Tôi có một chiếc RTX 3090 với 24GB VRAM, con số có vẻ là dư dả vào thời điểm đó. Tuy nhiên, chỉ vài giây sau khi nhấn ‘run’, terminal đã bị sập với lỗi CUDA Out of Memory (OOM) kinh điển. Ngay cả khi sử dụng batch size cực nhỏ và quantization 4-bit, chi phí bộ nhớ từ gradients và optimizer states vẫn vượt quá khả năng của phần cứng.

Cái gọi là ‘rào cản VRAM’ này là một trải nghiệm mà bất kỳ lập trình viên nào cũng phải trải qua khi chuyển từ việc gọi API đơn giản sang tùy chỉnh kiến trúc mô hình. Mặc dù bạn có thể muốn một mô hình hiểu được các thuật ngữ chuyên môn về luật pháp hoặc y tế, nhưng yêu cầu về phần cứng thường khiến chúng ta nản lòng. Việc thuê một chiếc H100 với giá 4,00 USD mỗi giờ sẽ nhanh chóng ngốn sạch ngân sách trong giai đoạn thử nghiệm.

Hiểu về nút thắt cổ chai bộ nhớ

Để giải quyết các vấn đề về bộ nhớ, chúng ta cần xem xét điều gì thực sự chiếm dụng GPU của bạn. Việc tải Llama 3 8B ở độ chính xác 16-bit (FP16) yêu cầu khoảng 15GB VRAM chỉ dành cho trọng số (weights). Quá trình huấn luyện còn đòi hỏi khắt khe hơn nhiều. Trong quá trình lan truyền ngược (backward pass), hệ thống của bạn phải lưu trữ đồng thời nhiều loại dữ liệu:

  • Model Weights: Các tham số cốt lõi của mạng nơ-ron.
  • Optimizer States: Dữ liệu được các thuật toán như AdamW sử dụng để theo dõi các cập nhật trọng số.
  • Gradients: Hướng và mức độ thay đổi được tính toán cho mỗi trọng số.
  • Activations: Các tính toán trung gian được lưu lại riêng cho quá trình lan truyền ngược.

Trong một thiết lập điển hình, optimizer states và gradients có thể tiêu tốn bộ nhớ gấp bốn lần so với chính mô hình đó. Ngay cả với LoRA (Low-Rank Adaptation), bộ nhớ dành cho activation vẫn tăng mạnh theo độ dài chuỗi (sequence length). Thử huấn luyện trên cửa sổ ngữ cảnh (context window) 4.096, và 24GB VRAM của bạn có khả năng sẽ biến mất trước khi epoch đầu tiên kết thúc.

Đánh giá các chiến lược Fine-tuning

Thử nghiệm các thư viện tối ưu hóa khác nhau cho thấy ba hướng đi riêng biệt cho các nhà phát triển:

1. Full Fine-tuning

Phương pháp này cập nhật mọi tham số trong mô hình. Đây là cách tiếp cận toàn diện nhất nhưng đòi hỏi phần cứng khổng lồ. Bạn thường cần một cụm 8x A100 để xử lý một mô hình 8B, khiến nó nằm ngoài tầm với của hầu hết các đội ngũ nhỏ.

2. LoRA và QLoRA tiêu chuẩn

LoRA đóng băng các trọng số cơ bản và thêm các lớp ‘adapter’ nhỏ có thể huấn luyện được. QLoRA cải thiện điều này bằng cách lượng tử hóa mô hình cơ sở xuống 4-bit. Mặc dù điều này giúp việc huấn luyện trên phần cứng dân dụng trở nên khả thi, nhưng bản thực thi PEFT tiêu chuẩn của Hugging Face thường chậm. Nó dựa trên các CUDA kernel chung chung không được tối ưu hóa cho các phép toán toán học cụ thể của LoRA.

3. Lợi thế từ Unsloth

Unsloth thay đổi cuộc chơi bằng cách viết lại các kernel toán học cốt lõi—đặc biệt là các lớp attention và MLP—sử dụng Triton. Nó tối ưu hóa quá trình backpropagation một cách thủ công và cắt giảm đáng kể chi phí bộ nhớ. Kết quả là tốc độ huấn luyện nhanh hơn gấp 2 lần và sử dụng ít bộ nhớ hơn 70% so với QLoRA tiêu chuẩn. Đây là cách hiệu quả nhất để triển khai AI tùy chỉnh mà không tốn kém chi phí đám mây khổng lồ.

Quy trình thực tế: Unsloth + QLoRA

Kết hợp các kernel đã tối ưu hóa của Unsloth với quantization 4-bit hiện là tiêu chuẩn vàng về hiệu quả. Tôi đã sử dụng thiết lập này để xử lý 50.000 lệnh (instructions) trong chưa đầy một giờ trên một GPU tầm trung duy nhất. Dưới đây là cách thực hiện.

Bước 1: Cấu hình môi trường

Bắt đầu với một môi trường Linux sạch và driver NVIDIA đã cập nhật. Tôi khuyên bạn nên sử dụng Conda để quản lý các thư viện phụ thuộc và tránh xung đột phiên bản.

conda create --name unsloth_env python=3.10 -y
conda activate unsloth_env

# Cài đặt Unsloth và các thư viện phụ thuộc thiết yếu
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes

Bước 2: Tải Llama 3

Chúng ta sử dụng loader của Unsloth thay vì class tiêu chuẩn của Hugging Face. Việc chuyển đổi này sẽ tự động kích hoạt các Triton kernel đã tối ưu hóa để đạt hiệu suất tốt hơn.

from unsloth import FastLanguageModel
import torch

max_seq_length = 2048 
dtype = None # Tự động phát hiện dựa trên GPU của bạn
load_in_4bit = True 

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

Bước 3: Cấu hình LoRA Adapters

Tiếp theo, chúng ta xác định các tham số LoRA. Bản thực thi của Unsloth tích hợp các adapter này vào đồ thị tính toán hiệu quả hơn so với các phương pháp thông thường.

model = FastLanguageModel.get_peft_model(
    model,
    r = 16, 
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth", # Quan trọng để tiết kiệm VRAM
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

Bước 4: Chuẩn bị dữ liệu

Định dạng dữ liệu chính xác là điều thiết yếu. Đối với Llama 3, bạn phải sử dụng chat template cụ thể để đảm bảo mô hình tuân thủ các hướng dẫn một cách chính xác sau khi huấn luyện.

from datasets import load_dataset

def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["output"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        # Tuân thủ nghiêm ngặt cấu trúc prompt của Llama 3
        text = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{instruction} {input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{output}<|eot_id|>"
        texts.append(text)
    return { "text" : texts, }

dataset = load_dataset("json", data_files="my_data.jsonl", split="train")
dataset = dataset.map(formatting_prompts_func, batched = True,)

Bước 5: Thực thi Trainer

Chúng ta sử dụng SFT (Supervised Fine-tuning) Trainer. Unsloth đã tinh chỉnh công cụ này để đảm bảo nó sử dụng backend tốc độ cao của họ trong suốt vòng lặp huấn luyện.

Share: