Fine-Tuning LLMs at Scale With Databricks MLflow and Spark

Learn how Databricks, Apache Spark, MLflow, and Hugging Face Transformers work together to create an end-to-end fine-tuning platform.

By Jubin Abhishek Soni (DZone Core) · Jun. 30, 26 · Analysis

Why Fine-Tune on Databricks?

General-purpose LLMs such as Llama 3, Mistral, or Falcon are impressive out of the box, but they often underperform on domain-specific tasks such as medical coding, legal clause extraction, internal support ticket classification, and financial report summarization. Fine-tuning adapts a pre-trained model's weights to your domain using your proprietary labeled data.

Doing this at scale introduces real engineering challenges:

  • Training data lives in Delta Lake across dozens of tables
  • GPU clusters need to be orchestrated, not hand-managed
  • Experiment tracking must be reproducible and auditable
  • Models need a promotion workflow before they touch production traffic

Databricks addresses each of these within a single platform:

  • Apache Spark for large-scale data preparation
  • MLflow (built-in) for experiment tracking, model registry, and lineage
  • Databricks Model Serving for one-click deployment with auto-scaling
  • Unity Catalog for governed model and data access

The ML Lifecycle Architecture

[Figure: ML lifecycle architecture]

Training Pipeline: End-to-End Flow

The flow below shows how a single training run moves through the system, from a triggered job to a promoted model alias.

[Figure: training pipeline, end-to-end flow]

Environment Setup

Python
 
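# Databricks Runtime ML 14.x+ recommended (ships CUDA, PyTorch, Transformers)
# Run %pip in its own notebook cell, or install via a cluster init script.
# bitsandbytes is required for 4-bit (QLoRA) loading; MLflow >= 2.11 can log PEFT models.
# Horovod is preinstalled on many Databricks Runtime ML versions; drop that line if present.

%pip install \
    transformers==4.40.0 \
    peft==0.10.0 \
    trl==0.8.6 \
    accelerate==0.29.3 \
    bitsandbytes==0.43.1 \
    "mlflow>=2.11" \
    horovod[spark]==0.28.1 \
    datasets==2.19.0 \
    evaluate==0.4.1 \
    --quiet

dbutils.library.restartPython()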

import os
import mlflow
import mlflow.transformers
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model, TaskType
from pyspark.sql import functions as F
from datasets import Dataset

# ── MLflow setup ──────────────────────────────────────────────────────────────
# On Databricks, MLflow tracking URI is pre-configured to the workspace
# mlflow.set_tracking_uri("databricks")   # uncomment for external clusters
mlflow.set_registry_uri("databricks-uc")   # register models in Unity Catalog (three-level names)

EXPERIMENT_NAME = "/Users/<your-email>/llm-finetuning/support-classifier"   # replace with your workspace user
mlflow.set_experiment(EXPERIMENT_NAME)

BASE_MODEL   = "mistralai/Mistral-7B-Instruct-v0.2"
CATALOG      = "prod"
GOLD_DB      = f"{CATALOG}.gold"
MODEL_NAME   = f"{CATALOG}.ml.support_intent_classifier"   # Unity Catalog model path

print(f"GPU available: {torch.cuda.is_available()}")
print(f"Device count:  {torch.cuda.device_count()}")


Preparing Training Data With Spark

Spark handles the heavy lifting before training: filtering noisy records, formatting prompt-response pairs, and splitting the dataset. This stage can run on a CPU cluster; GPU nodes only need to spin up for the training job itself.

Python
 
# ── Spark Data Preparation ────────────────────────────────────────────────────

# Mistral instruct template: [INST] <instruction + message> [/INST] <label>
PROMPT_PREFIX = "[INST] Classify the intent of this support message:\n\n"

# Load from Delta Gold table
raw_df = (
    spark.table(f"{GOLD_DB}.support_conversations")
        .filter(F.col("quality_score") >= 0.85)           # keep high-quality labels only
        .filter(F.col("intent_label").isNotNull())
        .filter(F.length("message") > 20)                 # filter empty/stub messages
        .filter(F.length("message") < 2048)               # filter messages too long to tokenize
        .dropDuplicates(["message_hash"])                  # remove exact duplicates
        .select("message", "intent_label", "created_date")
        .limit(500_000)                                    # cap for this training run
)

print(f"Training candidates: {raw_df.count():,}")

# Build prompt strings with native Spark functions (no Python UDF serialization overhead)
prepared_df = (
    raw_df
        .withColumn("prompt", F.concat(
            F.lit(PROMPT_PREFIX), F.col("message"),
            F.lit(" [/INST] "), F.col("intent_label")))
        .withColumn("token_count",
            F.size(F.split(F.col("prompt"), r"\s+")))         # rough word count proxy
        .filter(F.col("token_count") < 512)                   # stay within the 512-token training length
        .select("prompt", "token_count", "created_date")
        .cache()   # limit() is non-deterministic: reuse one materialization for all splits
)

# Random (not stratified) split, reproducible for a fixed input via the seed
train_df, val_df, test_df = prepared_df.randomSplit([0.80, 0.10, 0.10], seed=42)

# Persist splits to Delta for lineage + reproducibility
train_df.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_DB}.llm_train_split")
val_df.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_DB}.llm_val_split")
test_df.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_DB}.llm_test_split")

print(f"Train: {train_df.count():,} | Val: {val_df.count():,} | Test: {test_df.count():,}")
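To unit-test the filtering rules and prompt template without a cluster, the same per-record logic can be mirrored in plain Python. This is a sketch, not part of the original pipeline: `build_prompt`, `keep_record`, and `rough_token_count` are hypothetical helper names, and the thresholds simply copy the Spark filters above.

```python
# Hypothetical plain-Python mirror of the Spark filters and prompt template,
# useful for testing the rules locally.
PROMPT_PREFIX = "[INST] Classify the intent of this support message:\n\n"


def build_prompt(message, intent_label):
    """Mistral instruct format: the label follows the closing [/INST] tag."""
    return f"{PROMPT_PREFIX}{message} [/INST] {intent_label}"


def keep_record(message, intent_label, quality_score):
    """Same thresholds as the Spark job: quality >= 0.85, label present, 20 < length < 2048."""
    return (
        quality_score >= 0.85
        and intent_label is not None
        and 20 < len(message) < 2048
    )


def rough_token_count(prompt):
    """Whitespace word count, approximating F.size(F.split(prompt, r"\\s+"))."""
    return len(prompt.split())


records = [
    ("My order hasn't arrived and it's been 10 days", "shipping_delay", 0.92),
    ("hi", "greeting", 0.99),                                 # dropped: too short
    ("Please cancel my subscription today.", None, 0.95),     # dropped: no label
]
kept = [build_prompt(m, label) for m, label, q in records if keep_record(m, label, q)]
print(len(kept))                      # 1
print(rough_token_count(kept[0]))     # 19
```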


Fine-Tuning With Hugging Face + MLflow Tracking

We use QLoRA: the base model is loaded in 4-bit precision and frozen, and only a small set of LoRA (Low-Rank Adaptation) adapter matrices is trained. This sharply reduces GPU memory requirements compared to full fine-tuning, making a 7B-parameter model trainable on a single A100.
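The LoRA update itself is small enough to sketch directly. Assuming the standard formulation, where a frozen weight W is augmented by a scaled low-rank product (alpha / r) · B A, the pure-Python sketch below shows why an adapted layer starts out identical to the base layer and how few parameters the adapters add. The Mistral-7B shapes come from its published config (hidden size 4096, 8 key-value heads of 128 dimensions, 32 layers); the class and helper names are illustrative, not PEFT's API.

```python
import random


def matvec(M, x):
    """Dense matrix-vector product for lists of lists."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


class LoRALinear:
    """y = W x + (alpha / r) * B (A x), with W frozen and only A, B trainable."""

    def __init__(self, W, r, alpha, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(W), len(W[0])
        self.W = W                                     # frozen base weight (d_out x d_in)
        self.scale = alpha / r
        # A (r x d_in) starts random and B (d_out x r) starts at zero, so the
        # adapted layer initially reproduces the base layer exactly.
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        self.B = [[0.0] * r for _ in range(d_out)]

    def forward(self, x):
        base = matvec(self.W, x)
        delta = matvec(self.B, matvec(self.A, x))      # B (A x): the full B A product is never formed
        return [b + self.scale * d for b, d in zip(base, delta)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters for one adapted d_out x d_in matrix: A has r*d_in, B has d_out*r."""
    return r * (d_in + d_out)


# Mistral-7B attention shapes: q_proj 4096 -> 4096, v_proj 4096 -> 1024
# (8 key-value heads x 128 dims), repeated over 32 decoder layers.
per_layer = lora_param_count(4096, 4096, 16) + lora_param_count(4096, 1024, 16)
total = 32 * per_layer
print(f"trainable LoRA params: {total:,}")   # trainable LoRA params: 6,815,744
```

For r=16 on q_proj and v_proj this comes to roughly 6.8M trainable parameters, under 0.1% of the model.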

Python
 
# ── LoRA Fine-Tuning with MLflow Autolog ─────────────────────────────────────

# Convert Spark DataFrame to Hugging Face Dataset
train_pd  = spark.table(f"{GOLD_DB}.llm_train_split").select("prompt").toPandas()
val_pd    = spark.table(f"{GOLD_DB}.llm_val_split").select("prompt").toPandas()

hf_train  = Dataset.from_pandas(train_pd)
hf_val    = Dataset.from_pandas(val_pd)

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token

def tokenize(batch):
    return tokenizer(
        batch["prompt"],
        truncation=True,
        max_length=512,
        padding="max_length",
    )

hf_train_tok = hf_train.map(tokenize, batched=True, remove_columns=["prompt"])
hf_val_tok   = hf_val.map(tokenize,   batched=True, remove_columns=["prompt"])

# Load base model in 4-bit quantization (QLoRA); requires the bitsandbytes package
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare the quantized model for training (enables gradient checkpointing
# and input gradients, as recommended for k-bit fine-tuning)
base_model = prepare_model_for_kbit_training(base_model)

# Apply LoRA adapter config
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                    # rank: higher = more capacity, more memory
    lora_alpha=32,           # scaling factor (update is scaled by lora_alpha / r)
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],   # attention layers to adapt
    bias="none",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# For r=16 on q_proj/v_proj of Mistral-7B, expect about 6.8M trainable params (well under 1%)

# Training arguments
training_args = TrainingArguments(
    output_dir="/dbfs/tmp/llm-finetune/checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,           # effective batch size = 32
    warmup_ratio=0.03,
    learning_rate=2e-4,
    fp16=False,
    bf16=True,                               # use bfloat16 on A100/H100
    logging_steps=50,
    evaluation_strategy="steps",             # renamed to eval_strategy in transformers >= 4.41
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="mlflow",                      #  pipe all metrics to MLflow automatically
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=hf_train_tok,
    eval_dataset=hf_val_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# ── MLflow Run ────────────────────────────────────────────────────────────────
with mlflow.start_run(run_name="mistral-7b-lora-v1") as run:
    # Log hyperparameters manually for full auditability
    mlflow.log_params({
        "base_model":          BASE_MODEL,
        "lora_rank":           lora_config.r,
        "lora_alpha":          lora_config.lora_alpha,
        "lora_dropout":        lora_config.lora_dropout,
        "target_modules":      str(lora_config.target_modules),
        "quantization":        "4-bit QLoRA (nf4)",
        "train_samples":       len(hf_train_tok),
        "val_samples":         len(hf_val_tok),
        "epochs":              training_args.num_train_epochs,
        "effective_batch":     training_args.per_device_train_batch_size
                               * training_args.gradient_accumulation_steps,
        "learning_rate":       training_args.learning_rate,
    })

    # Train — metrics auto-logged to MLflow via report_to="mlflow"
    trainer.train()

    # Log final eval metrics explicitly
    eval_results = trainer.evaluate()
    mlflow.log_metrics({
        "final_eval_loss":       eval_results["eval_loss"],
        "final_eval_perplexity": torch.exp(torch.tensor(eval_results["eval_loss"])).item(),
    })

    # Log the model + tokenizer as a single MLflow artifact (PEFT models need MLflow >= 2.11)
    mlflow.transformers.log_model(
        transformers_model={"model": trainer.model, "tokenizer": tokenizer},
        artifact_path="model",
        task="text-generation",
        registered_model_name=MODEL_NAME,    # auto-registers to Unity Catalog
        metadata={"base_model": BASE_MODEL, "finetuning": "QLoRA"},
    )

    run_id = run.info.run_id
    print(f"Run ID: {run_id}")
    print(f"Eval Loss: {eval_results['eval_loss']:.4f}")
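Before launching, it helps to know how many optimizer steps these arguments imply. The sketch below approximates the Trainer's single-GPU step arithmetic; it assumes the 500,000-row cap is reached and the 80% split lands at exactly 400,000 rows (randomSplit only approximates its fractions), so treat the numbers as estimates. `trainer_schedule` is a hypothetical helper.

```python
import math


def trainer_schedule(n_train, per_device_batch, grad_accum, epochs,
                     warmup_ratio, eval_steps):
    """Approximate single-GPU step arithmetic for a Hugging Face Trainer run."""
    batches_per_epoch = math.ceil(n_train / per_device_batch)     # length of the train dataloader
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)   # optimizer steps per epoch
    max_steps = updates_per_epoch * epochs
    return {
        "effective_batch": per_device_batch * grad_accum,
        "updates_per_epoch": updates_per_epoch,
        "max_steps": max_steps,
        "warmup_steps": math.ceil(max_steps * warmup_ratio),
        "evaluations": max_steps // eval_steps,
    }


# Assumes exactly 400,000 training rows (80% of the 500,000-row cap)
schedule = trainer_schedule(400_000, per_device_batch=4, grad_accum=8, epochs=3,
                            warmup_ratio=0.03, eval_steps=200)
print(schedule)
```

At an effective batch of 32, three epochs come to about 37,500 optimizer steps, and evaluating every 200 steps means roughly 187 passes over the validation split, which is worth budgeting for.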


Distributed Training With Horovod on Spark

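For large datasets, or when single-node training cannot meet your wall-clock targets, Horovod distributes training across multiple GPU workers using ring-allreduce: each worker holds a full model replica, trains on its own data shard, and gradients are averaged across workers after every backward pass. Because every GPU holds the whole model, this is data parallelism: it shortens training time but does not raise the per-GPU memory ceiling, so models too large for one GPU need sharded approaches such as FSDP or DeepSpeed ZeRO instead. Newer Databricks Runtime ML releases also deprecate HorovodRunner in favor of TorchDistributor, so check your runtime version.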
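The data flow of ring-allreduce can be simulated in a single process. The sketch below is a toy model of the algorithm (reduce-scatter followed by all-gather), not Horovod's implementation, which runs the same pattern over NCCL or MPI on real GPUs; `ring_allreduce_mean` is a hypothetical name.

```python
def ring_allreduce_mean(grads):
    """Average equal-length gradient vectors across len(grads) simulated workers.

    Reduce-scatter: after n-1 steps, worker i owns the full sum of chunk (i + 1) % n.
    All-gather: after n-1 more steps, every worker holds every summed chunk.
    """
    n = len(grads)
    length = len(grads[0])
    bounds = [(k * length) // n for k in range(n + 1)]   # chunk k spans bounds[k]:bounds[k + 1]
    bufs = [list(g) for g in grads]                      # each worker's local buffer

    def send_round(chunk_for_sender, accumulate):
        # Every worker sends one chunk to its right-hand neighbour at the same time,
        # so snapshot all outgoing chunks before applying any of them.
        msgs = []
        for i in range(n):
            k = chunk_for_sender(i)
            msgs.append(((i + 1) % n, k, bufs[i][bounds[k]:bounds[k + 1]]))
        for dst, k, data in msgs:
            for offset, value in enumerate(data):
                pos = bounds[k] + offset
                bufs[dst][pos] = bufs[dst][pos] + value if accumulate else value

    for step in range(n - 1):                            # phase 1: reduce-scatter
        send_round(lambda i: (i - step) % n, accumulate=True)
    for step in range(n - 1):                            # phase 2: all-gather
        send_round(lambda i: (i + 1 - step) % n, accumulate=False)

    return [[v / n for v in buf] for buf in bufs]        # sum -> mean


workers = [
    [0.0, 1.0, 2.0, 3.0, 4.0],
    [10.0, 11.0, 12.0, 13.0, 14.0],
    [20.0, 21.0, 22.0, 23.0, 24.0],
]
print(ring_allreduce_mean(workers)[0])   # [10.0, 11.0, 12.0, 13.0, 14.0]
```

Each worker sends 2(n - 1) chunks of roughly 1/n of the gradient, so per-worker traffic stays close to twice the gradient size no matter how many workers join the ring.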

Python
 
# ── Distributed Fine-Tuning with Horovod on Spark ────────────────────────────
# Best for: large datasets, or when you need to reduce wall-clock training
# time below a business SLA. Each GPU holds a full model copy, so the model
# (plus gradients and optimizer state) must fit on a single GPU.

from sparkdl import HorovodRunner

def train_fn(hparams):
    """
    Training function executed on each Horovod worker.
    Each worker trains on a data shard; gradients are averaged across workers.
    Imports and model/tokenizer loading happen inside the function because it
    runs in separate worker processes.
    """
    import torch
    import mlflow
    import horovod.torch as hvd
    from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
    from datasets import load_from_disk

    hvd.init()

    local_rank = hvd.local_rank()   # GPU index on this node
    rank = hvd.rank()               # globally unique process index

    torch.cuda.set_device(local_rank)

    # Load this worker's shard, indexed by global rank so no two workers read
    # the same data. Shards (tokenized, with labels) must be written beforehand.
    dataset = load_from_disk(f"/dbfs/tmp/llm-finetune/train_shards/shard_{rank}")

    tokenizer = AutoTokenizer.from_pretrained(hparams["base_model"])
    model = AutoModelForCausalLM.from_pretrained(
        hparams["base_model"],
        torch_dtype=torch.bfloat16,
    ).to(f"cuda:{local_rank}")

    # Wrap optimizer with Horovod DistributedOptimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=hparams["lr"])
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=hvd.Compression.fp16,    # compress gradient communication
    )

    # Broadcast initial model weights and optimizer state from rank 0 to all workers
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    training_args = TrainingArguments(
        output_dir="/dbfs/tmp/llm-finetune/hvd_output",
        num_train_epochs=hparams["epochs"],
        per_device_train_batch_size=hparams["batch_size"],
        bf16=True,
        dataloader_num_workers=2,
        # Only rank 0 logs and saves, which avoids duplicated artifacts.
        # Workers need MLflow credentials (e.g., DATABRICKS_HOST / DATABRICKS_TOKEN) to log.
        report_to="mlflow" if rank == 0 else "none",
        save_strategy="epoch" if rank == 0 else "no",
    )

    # Simplification: Trainer is not Horovod-aware; it just steps the wrapped optimizer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        optimizers=(optimizer, None),   # None: Trainer creates its default LR scheduler
    )
    trainer.train()

    # Only rank 0 registers the model
    if rank == 0:
        mlflow.set_registry_uri("databricks-uc")
        mlflow.transformers.log_model(
            transformers_model={"model": model, "tokenizer": tokenizer},
            artifact_path="model",
            task="text-generation",
            registered_model_name=hparams["model_name"],
        )


# Launch distributed training across N GPU workers
# np = number of processes = number of GPUs across all worker nodes
hr = HorovodRunner(np=8, driver_log_verbosity="all")   # 8 GPUs (e.g., 2 × 4-GPU nodes)
hr.run(train_fn, hparams={
    "base_model": BASE_MODEL,
    "model_name": MODEL_NAME,
    "lr":         2e-5,
    "epochs":     3,
    "batch_size": 2,       # per GPU; effective = 2 × 8 = 16
})
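Each Horovod process must read a distinct shard. The sketch below assumes the common node-major rank layout (ranks 0-3 on the first node, 4-7 on the second) and shows what happens when shards are indexed by local rank versus global rank. `shard_assignment` and `contiguous_rows` are hypothetical helpers; the latter is intended to match the contiguous split that `datasets.Dataset.shard(num_shards, index, contiguous=True)` performs when writing the shards.

```python
def shard_assignment(nodes, gpus_per_node, index_by):
    """Return the shard index each process would load, in global-rank order."""
    shards = []
    for node in range(nodes):
        for local_rank in range(gpus_per_node):
            rank = node * gpus_per_node + local_rank   # assumed node-major rank layout
            shards.append(rank if index_by == "rank" else local_rank)
    return shards


def contiguous_rows(n_rows, world_size, rank):
    """Row range [start, end) for one shard; shard sizes differ by at most one row."""
    base, extra = divmod(n_rows, world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


by_local = shard_assignment(2, 4, index_by="local_rank")
by_rank = shard_assignment(2, 4, index_by="rank")
print(by_local)   # [0, 1, 2, 3, 0, 1, 2, 3]: shards 4-7 are never read
print(by_rank)    # [0, 1, 2, 3, 4, 5, 6, 7]
print([contiguous_rows(10, 3, r) for r in range(3)])   # [(0, 4), (4, 7), (7, 10)]
```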


MLflow Model Registry and Promotion

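Once a run completes, the logged model is registered as a new version in the MLflow Model Registry. With Unity Catalog, Databricks replaces the legacy stages (Staging, Production) with model aliases: named, movable pointers to a specific version, such as the candidate, staging, and champion aliases used here.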

Python
 
# ── Model Registry Promotion Workflow ─────────────────────────────────────────

from mlflow.tracking import MlflowClient

client = MlflowClient()

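# Get the newest registered version. Unity Catalog models do not support the
# stage-based latest_versions API, so search all versions and take the highest.
versions = client.search_model_versions(f"name='{MODEL_NAME}'")
latest_version = str(max(int(v.version) for v in versions))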

# Tag the new version as a candidate for review
client.set_registered_model_alias(
    name=MODEL_NAME,
    alias="candidate",
    version=latest_version,
)

client.set_model_version_tag(
    name=MODEL_NAME,
    version=latest_version,
    key="fine_tuned_on",
    value=f"{GOLD_DB}.support_conversations",
)

client.set_model_version_tag(
    name=MODEL_NAME,
    version=latest_version,
    key="eval_loss",
    value=str(round(eval_results["eval_loss"], 4)),
)

# After human review / automated eval gates pass → promote to staging
client.set_registered_model_alias(
    name=MODEL_NAME,
    alias="staging",
    version=latest_version,
)

# After integration tests pass → promote to champion (production)
client.set_registered_model_alias(
    name=MODEL_NAME,
    alias="champion",
    version=latest_version,
)

# Load model by alias — decouples code from version numbers
champion_model = mlflow.transformers.load_model(f"models:/{MODEL_NAME}@champion")
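Promotion gates and rollbacks reduce to moving aliases between versions. The toy registry below is an in-memory stand-in, not the MLflow client: in Unity Catalog the equivalent of `set_alias` is `client.set_registered_model_alias(...)`, and since an alias stores only its current target, one-step rollback assumes you record the previous version yourself (for example, in a model version tag). The gate rule and its threshold are hypothetical.

```python
class AliasRegistry:
    """Toy stand-in for registry aliases: each alias points at exactly one version."""

    def __init__(self):
        self.aliases = {}    # alias -> version
        self.previous = {}   # alias -> stack of earlier versions, kept for rollback

    def set_alias(self, alias, version):
        if alias in self.aliases:
            self.previous.setdefault(alias, []).append(self.aliases[alias])
        self.aliases[alias] = version

    def rollback(self, alias):
        """Point the alias back at the version it held before the last change."""
        self.aliases[alias] = self.previous[alias].pop()


def passes_gate(candidate_loss, champion_loss, max_regression=0.0):
    """Hypothetical eval gate: promote only if the candidate is not worse than the champion."""
    if champion_loss is None:          # first model: nothing to compare against
        return True
    return candidate_loss <= champion_loss + max_regression


registry = AliasRegistry()
registry.set_alias("champion", "3")
if passes_gate(candidate_loss=0.41, champion_loss=0.45):
    registry.set_alias("champion", "4")
print(registry.aliases["champion"])    # 4
registry.rollback("champion")
print(registry.aliases["champion"])    # 3
```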


Serving With Databricks Model Serving

Python
 
# ── Deploy to Databricks Model Serving ────────────────────────────────────────
# Can also be done via the UI: Models > Serving > Create Endpoint

import requests, json

WORKSPACE_URL = "https://<your-workspace>.azuredatabricks.net"
TOKEN         = dbutils.secrets.get("prod-scope", "databricks-token")

endpoint_config = {
    "name": "support-intent-classifier",
    "config": {
        "served_models": [
            {
                "name":                    "mistral-7b-lora-champion",
                "model_name":              MODEL_NAME,
                "model_version":           latest_version,
                "workload_size":           "Small",      # concurrency tier, not GPU count
                "scale_to_zero_enabled":   True,
                "workload_type":           "GPU_LARGE",  # A100-class GPU; available types vary by cloud and region
            }
        ],
        "traffic_config": {
            "routes": [
                {"served_model_name": "mistral-7b-lora-champion", "traffic_percentage": 100}
            ]
        },
        "auto_capture_config": {
            "catalog_name":  CATALOG,
            "schema_name":   "ml",
            "table_name_prefix": "support_classifier_inference_log",
            "enabled":       True,                        # log all requests/responses to Delta
        }
    }
}

response = requests.post(
    f"{WORKSPACE_URL}/api/2.0/serving-endpoints",
    headers={"Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json"},
    data=json.dumps(endpoint_config),
)
response.raise_for_status()   # surface 4xx/5xx errors instead of printing them silently
print(response.json())

# ── Query the endpoint ────────────────────────────────────────────────────────
def classify_intent(message: str) -> str:
    payload = {
        "inputs": {"prompt": f"[INST] Classify the intent of this support message:\n\n{message} [/INST]"},
        "params": {"max_new_tokens": 50, "temperature": 0.1},
    }
    resp = requests.post(
        f"{WORKSPACE_URL}/serving-endpoints/support-intent-classifier/invocations",
        headers={"Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json"},
        data=json.dumps(payload),
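    )
    resp.raise_for_status()
    return resp.json()["predictions"][0]

print(classify_intent("My order hasn't arrived and it's been 10 days"))
# Illustrative output: "shipping_delay" (depending on the generation config, the raw text may also echo the prompt)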
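Because a text-generation endpoint returns free text, a thin post-processing step makes the classifier safer to consume. The helper below is a hypothetical sketch: it assumes the Mistral `[/INST]` tag may be echoed in the output and that intent labels are single tokens from a known set, and it falls back to `"unknown"` for anything else.

```python
def extract_label(generated, allowed_labels, fallback="unknown"):
    """Return the predicted intent label from raw generated text, or `fallback`."""
    # If the prompt was echoed, keep only the text after the closing [/INST] tag
    completion = generated.split("[/INST]", 1)[-1]
    completion = completion.replace("</s>", " ").strip()   # drop an end-of-sequence marker if present
    if not completion:
        return fallback
    candidate = completion.split()[0].strip(".,;:!\"'")
    return candidate if candidate in allowed_labels else fallback


LABELS = {"shipping_delay", "refund_request", "account_access"}   # hypothetical label set
print(extract_label("[INST] Classify the intent ...\n\nWhere is my parcel? [/INST] shipping_delay", LABELS))
# shipping_delay
```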


Comparing Fine-Tuning Strategies

| Strategy | GPU Memory | Training Time | Quality vs. Full FT (rough, task-dependent) | When to Use |
|---|---|---|---|---|
| Full Fine-Tuning | Very High (80GB+) | Slowest | Baseline (100%) | Max quality, large budget |
| LoRA | Medium (24–40GB) | Fast | ~95% | Best general-purpose choice |
| QLoRA (4-bit + LoRA) | Low (10–16GB) | Medium | ~90–93% | Single GPU, cost-sensitive |
| Prefix Tuning | Low | Very Fast | ~80–85% | Minimal compute, quick iteration |
| Prompt Tuning | Very Low | Fastest | ~70–80% | Base weights frozen; only soft-prompt embeddings are trained |
| RLHF / DPO | High | Slowest | Best alignment | Instruction-following quality |
| Distillation | Medium (teacher) | Medium | Varies | Smaller, faster inference model |

Rule of thumb: Start with QLoRA on a single GPU. If eval loss stagnates or quality gates fail, move to LoRA on multi-GPU. Full fine-tuning is only warranted when you have >1M high-quality labeled examples and a measurable business case for the incremental quality gain.
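The memory column can be sanity-checked with back-of-the-envelope arithmetic. The sketch below counts only weights, gradients, and AdamW state under common mixed-precision assumptions (16 bytes per fully trained parameter, 2 bytes per frozen bf16 weight, about 0.5 bytes per frozen 4-bit weight, adapters at roughly 0.1% of parameters) for a model of roughly 7.24 billion parameters such as Mistral-7B. Activations and runtime overhead are ignored, so real usage is higher; `estimate_state_memory_gb` is a hypothetical helper.

```python
GB = 1e9


def estimate_state_memory_gb(n_params, method, trainable_fraction=0.001):
    """Rough memory for weights, gradients and AdamW state only.

    Excludes activations, the CUDA context and framework overhead, which often
    dominate in practice (hence the higher figures in the table above).
    """
    def full_training_bytes(count):
        # bf16 weights (2) + bf16 grads (2) + fp32 master weights (4) + AdamW m and v in fp32 (8)
        return count * 16

    adapters = full_training_bytes(n_params * trainable_fraction)
    if method == "full":
        total = full_training_bytes(n_params)
    elif method == "lora":
        total = n_params * 2 + adapters          # frozen bf16 base weights
    elif method == "qlora":
        total = n_params * 0.5 + adapters        # frozen 4-bit base (ignores quantization constants)
    else:
        raise ValueError(f"unknown method: {method}")
    return total / GB


for method in ("full", "lora", "qlora"):
    print(method, round(estimate_state_memory_gb(7.24e9, method), 1))
# full 115.8
# lora 14.6
# qlora 3.7
```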

Key Takeaways

  • Spark handles data preparation at scale before training begins: filtering, prompt formatting, and splitting run in parallel across the cluster.
  • QLoRA (LoRA adapters on a frozen 4-bit base model) makes fine-tuning 7B–13B models feasible on a single A100, with a far smaller memory footprint than full fine-tuning and modest quality loss.
  • Setting report_to="mlflow" gives you automatic experiment tracking with a single argument: loss curves, the learning rate schedule, and other Trainer metrics are captured at every logging step.
  • Unity Catalog model aliases (candidate → staging → champion) replace brittle version-number references in deployment code, making promotions and rollbacks a one-liner.
  • Inference tables (Auto Capture) on Databricks Model Serving log inference requests and responses to a Delta table, giving you a feedback loop for building your next training dataset.
  • Horovod on Spark fits when single-node training time exceeds your SLA and the model still fits on one GPU: it reuses your existing Spark cluster without a separate orchestration layer.


Published at DZone with permission of Jubin Abhishek Soni. See the original article here.
