Fine-Tuning LLMs at Scale With Databricks MLflow and Spark
Learn how Databricks, Apache Spark, MLflow, and Hugging Face Transformers work together to create an end-to-end fine-tuning platform.
Why Fine-Tune on Databricks?
General-purpose LLMs like Llama 3, Mistral, or Falcon are impressive out of the box, but they often underperform on domain-specific tasks such as medical coding, legal clause extraction, internal support ticket classification, and financial report summarization. Fine-tuning adapts a pre-trained model's weights to your domain using your proprietary labeled data.
Doing this at scale introduces real engineering challenges:
- Training data lives in Delta Lake across dozens of tables
- GPU clusters need to be orchestrated, not hand-managed
- Experiment tracking must be reproducible and auditable
- Models need a promotion workflow before they touch production traffic
Databricks brings the tools for each of these needs into one platform:
- Apache Spark for large-scale data preparation
- MLflow (built-in) for experiment tracking, model registry, and lineage
- Databricks Model Serving for one-click deployment with auto-scaling
- Unity Catalog for governed model and data access
The ML Lifecycle Architecture
Training Pipeline: End-to-End Flow
The flow below shows how a single training run moves through the system — from a triggered job to a promoted model alias.
Environment Setup
# Databricks Runtime ML 14.x+ recommended (ships CUDA, PyTorch, Transformers)
# Install additional packages in your cluster init script or notebook.
# bitsandbytes is required for the 4-bit (QLoRA) model loading used later.
%pip install \
  transformers==4.40.0 \
  peft==0.10.0 \
  trl==0.8.6 \
  accelerate==0.29.3 \
  bitsandbytes==0.43.1 \
  horovod[spark]==0.28.1 \
  datasets==2.19.0 \
  evaluate==0.4.1 \
  --quiet
dbutils.library.restartPython()
import os
import mlflow
import mlflow.transformers
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model, TaskType
from pyspark.sql import functions as F
from datasets import Dataset
# ── MLflow setup ──────────────────────────────────────────────────────────────
# On Databricks, MLflow tracking URI is pre-configured to the workspace
# mlflow.set_tracking_uri("databricks") # uncomment for external clusters
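EXPERIMENT_NAME = "/Users/<your-email>/llm-finetuning/support-classifier"  # replace with your workspace user
mlflow.set_experiment(EXPERIMENT_NAME)
# Register models in Unity Catalog (three-level names such as MODEL_NAME below)
# instead of the legacy workspace model registry
mlflow.set_registry_uri("databricks-uc")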
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
CATALOG = "prod"
GOLD_DB = f"{CATALOG}.gold"
MODEL_NAME = f"{CATALOG}.ml.support_intent_classifier" # Unity Catalog model path
print(f"GPU available: {torch.cuda.is_available()}")
print(f"Device count: {torch.cuda.device_count()}")
Preparing Training Data With Spark
Spark handles the heavy lifting before training: filtering noisy records, formatting prompt-response pairs, and splitting the dataset. This stage runs on the CPU cluster — GPU nodes only spin up for the actual training job.
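The model learns the exact template it is trained on, so the training prompt (instruction, message, [/INST], label) and the prompt sent at inference time (the same text, stopping at [/INST]) are safest when built from a single definition. A minimal plain-Python sketch of that idea; INSTRUCTION, inference_prompt, and training_prompt are hypothetical names, and the string mirrors the template used in this article's Spark UDF and endpoint query.

```python
INSTRUCTION = "Classify the intent of this support message:"


def inference_prompt(message: str) -> str:
    """Prompt sent at serving time: everything up to and including [/INST]."""
    return f"[INST] {INSTRUCTION}\n\n{message} [/INST]"


def training_prompt(message: str, intent_label: str) -> str:
    """Training example: the inference prompt followed by the target label."""
    return f"{inference_prompt(message)} {intent_label}"


example = training_prompt("My order hasn't arrived", "shipping_delay")
print(example)
```

In Spark, a two-argument function like training_prompt can be wrapped directly, for example F.udf(training_prompt, "string"), so the same definition serves both data preparation and the serving client.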
# ── Spark Data Preparation ────────────────────────────────────────────────────
def build_prompt(message, intent_label):
    """
    Format a support message and its label as an instruction-following example.
    Uses the Mistral instruct template: [INST] ... [/INST]
    """
    return f"[INST] Classify the intent of this support message:\n\n{message} [/INST] {intent_label}"
# Load from Delta Gold table
raw_df = (
    spark.table(f"{GOLD_DB}.support_conversations")
    .filter(F.col("quality_score") >= 0.85)      # keep high-quality labels only
    .filter(F.col("intent_label").isNotNull())
    .filter(F.length("message") > 20)            # drop empty/stub messages
    .filter(F.length("message") < 2048)          # drop very long messages (characters, not tokens)
    .dropDuplicates(["message_hash"])            # remove exact duplicates
    .select("message", "intent_label", "created_date")
    .limit(500_000)                              # cap for this training run
)
print(f"Training candidates: {raw_df.count():,}")
# Build prompt strings with a Python UDF, parallelized across all workers
prompt_udf = F.udf(build_prompt, returnType="string")
prepared_df = (
    raw_df
    .withColumn("prompt", prompt_udf(F.col("message"), F.col("intent_label")))
    .withColumn("token_count",
                F.size(F.split(F.col("prompt"), r"\s+")))  # rough proxy: whitespace-separated words
    # Words undercount subword tokens, so leave headroom below the 512-token
    # max_length used at tokenization; otherwise truncation can cut off the label
    .filter(F.col("token_count") < 384)
    .select("prompt", "token_count", "created_date")
    # limit() without an ordering is non-deterministic; caching helps later
    # actions reuse the same rows so the splits below do not overlap
    .cache()
)
# Random (not stratified) split; reproducible with a fixed seed on fixed input
train_df, val_df, test_df = prepared_df.randomSplit([0.80, 0.10, 0.10], seed=42)
# Persist splits to Delta for lineage + reproducibility
train_df.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_DB}.llm_train_split")
val_df.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_DB}.llm_val_split")
test_df.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_DB}.llm_test_split")
print(f"Train: {train_df.count():,} | Val: {val_df.count():,} | Test: {test_df.count():,}")
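randomSplit samples rows independently of their label, so rare intents can end up under-represented in the validation or test split. When per-class balance matters, one option is to split each intent separately and concatenate the results. A minimal pure-Python sketch of a stratified 80/10/10 split, assuming records are (message, intent_label) pairs; stratified_split is a hypothetical helper, not a Spark API. Note that prepared_df above keeps only prompt, token_count, and created_date, so stratifying in Spark would also mean carrying intent_label through that select.

```python
import random
from collections import defaultdict


def stratified_split(records, fractions=(0.8, 0.1, 0.1), seed=42):
    """Split (message, label) records so each label keeps roughly the same
    proportion in every split. Returns one list per fraction."""
    by_label = defaultdict(list)
    for rec in records:
        by_label[rec[1]].append(rec)

    rng = random.Random(seed)  # seeded for reproducibility
    splits = [[] for _ in fractions]
    for label in sorted(by_label):  # sorted so results don't depend on dict order
        group = by_label[label]
        rng.shuffle(group)
        start = 0
        for i, frac in enumerate(fractions):
            # The last split takes the remainder so no record is dropped
            if i == len(fractions) - 1:
                end = len(group)
            else:
                end = start + round(frac * len(group))
            splits[i].extend(group[start:end])
            start = end
    return splits


data = [(f"msg {i}", "billing") for i in range(80)] + \
       [(f"msg {i}", "refund") for i in range(80, 100)]
train, val, test = stratified_split(data)
print(len(train), len(val), len(test))
```

With 80 "billing" and 20 "refund" records, each split keeps the 80/20 label mix instead of leaving it to chance.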
Fine-Tuning With Hugging Face + MLflow Tracking
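We use LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning technique that freezes the base model and trains only a small set of low-rank adapter matrices. Because gradients and optimizer state are kept only for the adapters, memory use drops sharply compared to full fine-tuning (roughly 70% is a commonly cited figure), which makes 7B-parameter models trainable on a single A100. The code below goes one step further and loads the frozen base model in 4-bit precision (QLoRA), reducing memory again.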
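LoRA learns the update to a frozen weight matrix W (shape d_out × d_in) as the product of two small matrices, B (d_out × r) and A (r × d_in), so each adapted matrix adds r · (d_in + d_out) trainable parameters. A back-of-the-envelope count for the config used below (r=16 on q_proj and v_proj), assuming Mistral-7B's published dimensions: hidden size 4096, 32 layers, and grouped-query attention with 8 key/value heads of size 128, so v_proj maps 4096 to 1024.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added by one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


hidden, n_layers, kv_dim, rank = 4096, 32, 8 * 128, 16

per_layer = (
    lora_params(hidden, hidden, rank)    # q_proj: 4096 -> 4096
    + lora_params(hidden, kv_dim, rank)  # v_proj: 4096 -> 1024
)
total = per_layer * n_layers
print(f"{total:,} trainable LoRA parameters")  # 6,815,744
```

That is under 0.1% of the roughly 7.2B parameters in the base model. Adding k_proj and o_proj as targets would double the count to 13,631,488, since k_proj has the same shape as v_proj and o_proj the same shape as q_proj.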
# ── LoRA Fine-Tuning with MLflow Autolog ─────────────────────────────────────
# Convert Spark DataFrame to Hugging Face Dataset
train_pd = spark.table(f"{GOLD_DB}.llm_train_split").select("prompt").toPandas()
val_pd = spark.table(f"{GOLD_DB}.llm_val_split").select("prompt").toPandas()
hf_train = Dataset.from_pandas(train_pd)
hf_val = Dataset.from_pandas(val_pd)
# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token
def tokenize(batch):
    return tokenizer(
        batch["prompt"],
        truncation=True,
        max_length=512,
        padding="max_length",
    )
hf_train_tok = hf_train.map(tokenize, batched=True, remove_columns=["prompt"])
hf_val_tok = hf_val.map(tokenize, batched=True, remove_columns=["prompt"])
# Load base model in 4-bit quantization (QLoRA)
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
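from peft import prepare_model_for_kbit_training
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # Mistral is supported natively; no trust_remote_code needed
)
# Standard PEFT preparation for k-bit training: casts the remaining 16-bit
# layers to fp32 and enables gradient checkpointing to save activation memory
base_model = prepare_model_for_kbit_training(base_model)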
# Apply LoRA adapter config
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=16, # rank — higher = more capacity, more memory
lora_alpha=32, # scaling factor
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"], # attention layers to adapt
bias="none",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# Expect about 6.8M trainable params (r=16 on q_proj and v_proj across 32 layers),
# well under 1% of the model; the all-params figure depends on how your PEFT
# version counts 4-bit weights
# Training arguments
training_args = TrainingArguments(
output_dir="/dbfs/tmp/llm-finetune/checkpoints",
num_train_epochs=3,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=8, # effective batch size = 32
warmup_ratio=0.03,
learning_rate=2e-4,
fp16=False,
bf16=True, # use bfloat16 on A100/H100
logging_steps=50,
evaluation_strategy="steps",  # name in transformers 4.40; renamed eval_strategy in 4.41+
eval_steps=200,
save_strategy="steps",
save_steps=200,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
report_to="mlflow", # pipe all metrics to MLflow automatically
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=hf_train_tok,
eval_dataset=hf_val_tok,
tokenizer=tokenizer,
data_collator=data_collator,
)
# ── MLflow Run ────────────────────────────────────────────────────────────────
with mlflow.start_run(run_name="mistral-7b-lora-v1") as run:
    # Log hyperparameters manually for full auditability
    mlflow.log_params({
        "base_model": BASE_MODEL,
        "lora_rank": lora_config.r,
        "lora_alpha": lora_config.lora_alpha,
        "lora_dropout": lora_config.lora_dropout,
        "target_modules": str(lora_config.target_modules),
        "quantization": "4-bit QLoRA (nf4)",
        "train_samples": len(hf_train_tok),
        "val_samples": len(hf_val_tok),
        "epochs": training_args.num_train_epochs,
        "effective_batch": training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps,
        "learning_rate": training_args.learning_rate,
    })
    # Train: metrics are auto-logged to this run via report_to="mlflow"
    trainer.train()
    # Log final eval metrics explicitly
    eval_results = trainer.evaluate()
    mlflow.log_metrics({
        "final_eval_loss": eval_results["eval_loss"],
        "final_eval_perplexity": torch.exp(torch.tensor(eval_results["eval_loss"])).item(),
    })
    # Log the model + tokenizer as a single MLflow artifact
    mlflow.transformers.log_model(
        transformers_model={"model": trainer.model, "tokenizer": tokenizer},
        artifact_path="model",
        task="text-generation",
        registered_model_name=MODEL_NAME,  # auto-registers to Unity Catalog
        metadata={"base_model": BASE_MODEL, "finetuning": "QLoRA"},
    )
    run_id = run.info.run_id
print(f"Run ID: {run_id}")
print(f"Eval Loss: {eval_results['eval_loss']:.4f}")
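Before launching, it can help to check how many optimizer steps, evaluations, and warmup steps these TrainingArguments imply. A rough sketch that approximately mirrors the Trainer's accounting (batches per epoch, floor-divided by gradient accumulation steps); the training-set size of 400,000 is a hypothetical figure (80% of the 500,000-row cap), since the real count depends on the filters.

```python
import math


def schedule(n_train, per_device_batch, grad_accum, n_gpus, epochs,
             eval_steps, warmup_ratio):
    """Approximate optimizer-step accounting for a Hugging Face Trainer run."""
    batches_per_epoch = math.ceil(n_train / (per_device_batch * n_gpus))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": per_device_batch * grad_accum * n_gpus,
        "total_steps": total_steps,
        "evaluations": total_steps // eval_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
    }


plan = schedule(n_train=400_000, per_device_batch=4, grad_accum=8, n_gpus=1,
                epochs=3, eval_steps=200, warmup_ratio=0.03)
print(plan)
```

Under that assumption the run has 37,500 optimizer steps and 187 evaluation passes, each over the full validation split, so eval_steps can noticeably affect total wall-clock time.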
Distributed Training With Horovod on Spark
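For large datasets, or when single-node training cannot finish within your time budget, Horovod distributes training across multiple GPU workers using ring-allreduce: each worker holds a full model replica and trains on its own shard of the data, and gradients are averaged across workers after every backward pass. Because every worker keeps a full replica, this data parallelism shortens wall-clock time but does not reduce per-GPU memory; a model that does not fit on one GPU needs a model-sharding approach instead (for example, PyTorch FSDP or DeepSpeed ZeRO).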
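Averaging per-worker gradients is equivalent to computing the gradient on the combined batch, which can be checked on a toy problem. A minimal pure-Python sketch, assuming a one-parameter linear model y ≈ w·x with mean-squared-error loss and equal-sized shards; the allreduce is simulated with a plain average, not Horovod's ring algorithm.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.8, 16.2]
w = 0.5
n_workers = 4

# Each simulated worker gets an equal, contiguous shard of the global batch
shard = len(xs) // n_workers
local_grads = [
    mse_grad(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
    for i in range(n_workers)
]
averaged = sum(local_grads) / n_workers  # what allreduce-averaging produces
full_batch = mse_grad(w, xs, ys)         # gradient on all 8 examples at once

print(averaged, full_batch)
```

The two values match (up to floating-point rounding) because the loss is a mean over examples and every shard has the same size; with uneven shards, a plain average would weight small shards too heavily.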
# ── Distributed Fine-Tuning with Horovod on Spark ────────────────────────────
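# Best for: large datasets, or when you need to cut wall-clock training time
# below a business SLA. Each worker holds a full model replica, so this does
# not help a model that is too large for one GPU.
import horovod.torch as hvd
from sparkdl import HorovodRunner
def train_fn(hparams):
    """
    Training function executed on each Horovod worker.
    Each worker trains on its own data shard; Horovod averages gradients
    across workers after every backward pass. BASE_MODEL, MODEL_NAME,
    lora_config, and tokenizer come from the driver (HorovodRunner pickles
    them along with the function).
    """
    import mlflow
    import mlflow.transformers
    import torch
    import horovod.torch as hvd
    from datasets import load_from_disk
    from peft import get_peft_model
    from torch.utils.data import DataLoader
    from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling

    hvd.init()
    local_rank = hvd.local_rank()  # GPU index on this node (repeats across nodes)
    global_rank = hvd.rank()       # unique across the whole job
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # Each worker loads only its own pre-tokenized shard (written beforehand).
    # Key on the global rank: local ranks repeat on every node, so keying on
    # them would make workers on different nodes read the same shard.
    dataset = load_from_disk(f"/dbfs/tmp/llm-finetune/train_shards/shard_{global_rank}")
    loader = DataLoader(
        dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
        collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    # Full fine-tuning of a 7B model with AdamW needs more memory than one GPU
    # has (weights + gradients + optimizer state), so reuse the LoRA config
    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.bfloat16)
    model = get_peft_model(model, lora_config).to(device)

    # Optimize only the trainable adapter parameters; Horovod's wrapper
    # allreduces their gradients across workers before each update
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([p for _, p in trainable], lr=hparams["lr"])
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=trainable,
        compression=hvd.Compression.fp16,  # compress gradient communication
    )
    # Start all workers from rank 0's adapter weights and optimizer state
    hvd.broadcast_parameters(trainable, root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    model.train()
    for epoch in range(hparams["epochs"]):
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if global_rank == 0:
            print(f"epoch {epoch}: last batch loss {loss.item():.4f}")

    # Only rank 0 logs and registers the model, to avoid duplicate artifacts.
    # MLflow calls from workers may need workspace credentials in the
    # environment; see the HorovodRunner documentation.
    if global_rank == 0:
        mlflow.set_registry_uri("databricks-uc")
        mlflow.transformers.log_model(
            transformers_model={"model": model, "tokenizer": tokenizer},
            artifact_path="model",
            task="text-generation",
            registered_model_name=MODEL_NAME,
        )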
# Launch distributed training across N GPU workers
# np = number of processes = number of GPUs across all nodes
hr = HorovodRunner(np=8, driver_log_verbosity="all") # 8 GPUs (e.g., 2 × 4-GPU nodes)
hr.run(train_fn, hparams={
"lr": 2e-5,
"epochs": 3,
"batch_size": 2, # per GPU; effective = 2 × 8 = 16
})
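With np=8 spread over two 4-GPU nodes, every process has a local rank (0 to 3, its GPU index on that node) and a global rank (0 to 7, unique across the job), and any per-worker resource such as a data shard has to be keyed on the global rank. A small sketch that enumerates both, assuming ranks are assigned contiguously per node; that is a common layout but may not hold on every launcher, so in real code read the values from hvd.rank() and hvd.local_rank(). rank_table is a hypothetical helper.

```python
def rank_table(n_nodes: int, gpus_per_node: int):
    """List (node, local_rank, global_rank) under a contiguous per-node layout."""
    return [
        (node, local, node * gpus_per_node + local)
        for node in range(n_nodes)
        for local in range(gpus_per_node)
    ]


table = rank_table(n_nodes=2, gpus_per_node=4)
local_paths = {f"shard_{local}" for _, local, _ in table}
global_paths = {f"shard_{g}" for _, _, g in table}
print(len(local_paths), "distinct shards if keyed on local rank")
print(len(global_paths), "distinct shards if keyed on global rank")
```

Keying on the local rank yields only 4 distinct paths for 8 workers, so half the workers would train on duplicated data while half the shards go unused.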
MLflow Model Registry and Promotion
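Once a run completes, the model is registered in the MLflow Model Registry, here backed by Unity Catalog. Unity Catalog models use aliases, named pointers to a specific model version (this article uses candidate, staging, and champion), instead of the legacy Staging/Production stage workflow.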
# ── Model Registry Promotion Workflow ─────────────────────────────────────────
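from mlflow.tracking import MlflowClient
client = MlflowClient()
# Find the newest registered version. Unity Catalog models do not populate
# latest_versions, so search the versions and take the highest number.
versions = client.search_model_versions(f"name='{MODEL_NAME}'")
latest_version = max(int(v.version) for v in versions)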
# Tag the new version as a candidate for review
client.set_registered_model_alias(
name=MODEL_NAME,
alias="candidate",
version=latest_version,
)
client.set_model_version_tag(
name=MODEL_NAME,
version=latest_version,
key="fine_tuned_on",
value="gold.support_conversations",
)
client.set_model_version_tag(
name=MODEL_NAME,
version=latest_version,
key="eval_loss",
value=str(round(eval_results["eval_loss"], 4)),
)
# After human review / automated eval gates pass → promote to staging
client.set_registered_model_alias(
name=MODEL_NAME,
alias="staging",
version=latest_version,
)
# After integration tests pass → promote to champion (production)
client.set_registered_model_alias(
name=MODEL_NAME,
alias="champion",
version=latest_version,
)
# Load model by alias — decouples code from version numbers
champion_model = mlflow.transformers.load_model(f"models:/{MODEL_NAME}@champion")
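The aliases above are moved unconditionally; in practice each move sits behind a gate. A minimal sketch of an automated eval-loss gate, with a hypothetical should_promote helper and a hypothetical 1% minimum improvement. The champion's loss could come from the eval_loss tag set above, for example via client.get_model_version_by_alias(MODEL_NAME, "champion").tags["eval_loss"], which errors if no champion alias exists yet. Eval losses are only comparable when both versions were evaluated on the same validation split.

```python
from typing import Optional


def should_promote(candidate_eval_loss: float,
                   champion_eval_loss: Optional[float],
                   min_relative_gain: float = 0.01) -> bool:
    """Return True if the candidate should replace the current champion.

    Promotes when there is no champion yet, or when the candidate's eval loss
    is at least `min_relative_gain` (1% by default) below the champion's.
    """
    if champion_eval_loss is None:
        return True
    return candidate_eval_loss <= champion_eval_loss * (1.0 - min_relative_gain)


print(should_promote(0.98, 1.00))   # True: 2% lower loss
print(should_promote(0.995, 1.00))  # False: only 0.5% lower
print(should_promote(1.20, None))   # True: no champion to beat
```

Requiring a minimum relative gain, rather than any improvement, keeps noise between runs from flipping the champion back and forth.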
Serving With Databricks Model Serving
# ── Deploy to Databricks Model Serving ────────────────────────────────────────
# Can also be done via the UI: Models > Serving > Create Endpoint
import requests, json
WORKSPACE_URL = "https://<your-workspace>.azuredatabricks.net"
TOKEN = dbutils.secrets.get("prod-scope", "databricks-token")
endpoint_config = {
    "name": "support-intent-classifier",
    "config": {
        "served_models": [
            {
                "name": "mistral-7b-lora-champion",
                "model_name": MODEL_NAME,
                "model_version": latest_version,
                "workload_size": "Small",      # provisioned-concurrency tier, not a GPU count
                "scale_to_zero_enabled": True,
                "workload_type": "GPU_LARGE",  # A100 on Azure, not an A10G; names and hardware vary by cloud
            }
        ],
        "traffic_config": {
            "routes": [
                {"served_model_name": "mistral-7b-lora-champion", "traffic_percentage": 100}
            ]
        },
        "auto_capture_config": {
            "catalog_name": CATALOG,
            "schema_name": "ml",
            "table_name_prefix": "support_classifier_inference_log",
            "enabled": True,  # log all requests/responses to Delta
        }
    }
}
response = requests.post(
f"{WORKSPACE_URL}/api/2.0/serving-endpoints",
headers={"Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json"},
data=json.dumps(endpoint_config),
)
print(response.json())
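Endpoint creation is asynchronous: the POST returns while the serving container is still being built, so queries sent right away may fail. A sketch of a polling helper, assuming the endpoint's state is read with GET {WORKSPACE_URL}/api/2.0/serving-endpoints/{name} and that the response's state.ready field reads "READY" once the endpoint can serve. wait_until_ready is hypothetical, and the fetch function is injected so the loop can be exercised without a workspace.

```python
import time
from typing import Callable

# In the notebook, fetch_state could be (assumed response shape):
# fetch_state = lambda: requests.get(
#     f"{WORKSPACE_URL}/api/2.0/serving-endpoints/support-intent-classifier",
#     headers={"Authorization": f"Bearer {TOKEN}"},
# ).json()["state"]


def wait_until_ready(fetch_state: Callable[[], dict],
                     timeout_s: float = 3600,
                     poll_s: float = 30,
                     sleep: Callable[[float], None] = time.sleep) -> dict:
    """Poll fetch_state() until state["ready"] == "READY" or raise TimeoutError."""
    deadline = time.monotonic() + timeout_s
    while True:
        state = fetch_state()
        if state.get("ready") == "READY":
            return state
        if time.monotonic() >= deadline:
            raise TimeoutError(f"endpoint not ready after {timeout_s}s: {state}")
        sleep(poll_s)
```

Injecting sleep and fetch_state keeps the retry logic testable; in a notebook, the defaults poll every 30 seconds for up to an hour.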
# ── Query the endpoint ────────────────────────────────────────────────────────
def classify_intent(message: str) -> str:
    payload = {
        "inputs": {"prompt": f"[INST] Classify the intent of this support message:\n\n{message} [/INST]"},
        "params": {"max_new_tokens": 50, "temperature": 0.1},
    }
    resp = requests.post(
        f"{WORKSPACE_URL}/serving-endpoints/support-intent-classifier/invocations",
        headers={"Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json"},
        data=json.dumps(payload),
    )
    resp.raise_for_status()  # surface HTTP errors instead of a KeyError below
    return resp.json()["predictions"][0]
print(classify_intent("My order hasn't arrived and it's been 10 days"))
# → "shipping_delay"
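Generated text is free-form: depending on the serving configuration it may echo the prompt, add whitespace or punctuation, or run past the label. Mapping the raw output onto the known label set keeps downstream code from acting on a malformed prediction. A sketch with a hypothetical normalize_intent helper and a hypothetical label set; the only assumption about the output is that the label, if present, is the first word after the last [/INST] marker (or the first word overall when the prompt is not echoed).

```python
from typing import Iterable, Optional


def normalize_intent(generated: str, allowed: Iterable[str]) -> Optional[str]:
    """Return the allowed label the model produced, or None if none matches."""
    # Keep only what follows the last [/INST], in case the prompt was echoed
    answer = generated.rsplit("[/INST]", 1)[-1].strip().lower()
    if not answer:
        return None
    first_word = answer.split()[0].strip(".,:;\"'")
    lookup = {label.lower(): label for label in allowed}  # case-insensitive match
    return lookup.get(first_word)


labels = ["shipping_delay", "refund_request", "account_access"]
print(normalize_intent(" shipping_delay", labels))
print(normalize_intent("[INST] ... [/INST] Refund_Request.\nThanks!", labels))
print(normalize_intent("I think this is about shipping", labels))
```

Returning None for anything outside the label set lets the caller route uncertain messages to a fallback (for example, a human queue) rather than guessing.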
Comparing Fine-Tuning Strategies
| Strategy | GPU Memory | Training Time | Quality vs Full FT (rough, task-dependent) | When to Use |
|---|---|---|---|---|
| Full Fine-Tuning | Very High (80GB+) | Slowest | Baseline (100%) | Max quality, large budget |
| LoRA | Medium (24–40GB) | Fast | ~95% | Best general-purpose choice |
| QLoRA (4-bit + LoRA) | Low (10–16GB) | Medium | ~90–93% | Single GPU, cost-sensitive |
| Prefix Tuning | Low | Very Fast | ~80–85% | Minimal compute, quick iteration |
| Prompt Tuning | Very Low | Fastest | ~70–80% | Minimal compute; base weights frozen, only soft-prompt embeddings trained |
| RLHF / DPO | High | Slowest | Best alignment | Instruction-following quality |
| Distillation | Medium (teacher) | Medium | Varies | Smaller, faster inference model |
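The GPU Memory column follows largely from bytes per parameter. A rough sketch for a 7B-parameter model that counts only weights, gradients, and optimizer state, assuming the common mixed-precision AdamW accounting of 16 bytes per trainable parameter (16-bit weights and gradients plus fp32 master weights and two fp32 Adam moments), 2 bytes per frozen bf16 weight, and 0.5 bytes per frozen 4-bit weight. It ignores activations, quantization constants, layers that stay in 16-bit (such as the embeddings), and framework overhead.

```python
GB = 1e9
n_params = 7e9        # base model size (approximate)
n_lora = 6_815_744    # LoRA r=16 on q_proj/v_proj for Mistral-7B

full_ft = n_params * 16               # every parameter trained with AdamW
lora = n_params * 2 + n_lora * 16     # frozen bf16 base + trained adapters
qlora = n_params * 0.5 + n_lora * 16  # frozen 4-bit base + trained adapters

for name, nbytes in [("full FT", full_ft),
                     ("LoRA (bf16 base)", lora),
                     ("QLoRA (4-bit base)", qlora)]:
    print(f"{name:>20}: ~{nbytes / GB:.1f} GB before activations")
```

This gives roughly 112 GB, 14.1 GB, and 3.6 GB respectively; the table's ranges sit above these figures because activations and the other excluded terms come on top.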
Rule of thumb: Start with QLoRA on a single GPU. If eval loss stagnates or quality gates fail, move to LoRA on multi-GPU. Full fine-tuning is only warranted when you have >1M high-quality labeled examples and a measurable business case for the incremental quality gain.
Key Takeaways
- Spark handles data preparation at scale before training begins: filtering, deduplication, prompt formatting, and train/val/test splitting all run in parallel across the cluster.
- LoRA, and QLoRA in particular, make fine-tuning 7B–13B models feasible on a single A100, cutting the memory footprint substantially with minimal quality loss.
- MLflow's report_to="mlflow" integration gives you automatic experiment tracking with a single argument: every loss curve, gradient norm, and learning rate schedule the Trainer logs is captured.
- Unity Catalog model aliases (candidate → staging → champion) replace brittle version-number references in deployment code, making promotions and rollbacks a one-liner.
- Auto Capture on Databricks Model Serving logs every inference request and response to a Delta table, giving you a feedback loop to build your next training dataset.
- Horovod on Spark is the right tool when single-node training exceeds your SLA — it leverages your existing Spark cluster without a separate orchestration layer.
References
- Databricks — LLM Fine-Tuning on Databricks
- MLflow — Transformers Flavor Documentation
- Hugging Face PEFT — LoRA & QLoRA
- QLoRA Paper — "QLoRA: Efficient Finetuning of Quantized LLMs" (Dettmers et al., 2023)
- LoRA Paper — "LoRA: Low-Rank Adaptation of Large Language Models" (Hu et al., 2021)
- Databricks — Model Serving (Foundation Model APIs)
- Horovod on Spark — Official Documentation
- Databricks — HorovodRunner API
- Databricks — Inference Tables (Auto Capture)
- "Training language models to follow instructions with human feedback" — InstructGPT / RLHF (OpenAI, 2022)
Published at DZone with permission of Jubin Abhishek Soni. See the original article here.
Opinions expressed by DZone contributors are their own.

