Simple Diffusion Language Modeling
dLLM is a library that unifies the training and evaluation of diffusion language models, bringing transparency and reproducibility to the entire development pipeline:
- dLLM provides scalable training pipelines (inspired by transformers' Trainer), with support for LoRA, DeepSpeed, FSDP and beyond.
- dLLM provides unified evaluation pipelines (inspired by lm-evaluation-harness) that abstract away inference details and make customization simple.
- Built on these components, dLLM provides minimal pretraining / finetuning / evaluation recipes for open-weight models (e.g., LLaDA and Dream), and implementations of training algorithms (e.g., Edit Flows).
[2025/11] We released a collection of BERTs finetuned for instruction-following: ModernBERT-{large,base}-chat-v0. This proof-of-concept shows that BERT's internal knowledge can be leveraged for generative tasks via masked instruction tuning. See the BERT Chat Report for detailed recipes, experimental results and lessons learned, and see examples/bert for training / inference / evaluation instructions.
- examples/llada: Pretraining, finetuning and evaluating LLaDA / LLaDA-MoE.
- examples/dream: Pretraining, finetuning and evaluating Dream.
- examples/bert: Finetuning any BERT into a lightweight chatbot.
- examples/editflow: Educational reference for training EditFlow models, demonstrating how to extend existing DLLMs (e.g., LLaDA, Dream, BERT Chat) with edit operations (insertion, deletion, and substitution) and how to pretrain or finetune EditFlow models from scratch on public data.
- More upcoming.
# create and activate conda environment
conda create -n dllm python=3.10 -y
conda activate dllm
# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
conda install cuda=12.4 -c nvidia
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
--index-url https://download.pytorch.org/whl/cu124
# install dllm package
pip install -e .
# initialize `lm-evaluation-harness` submodule
git submodule update --init --recursive
# install submodule in editable mode with IFEval & Math dependencies
pip install -e "lm-evaluation-harness[ifeval,math]"
For Slurm users, update scripts/train.slurm.sh for your cluster:
- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
- #SBATCH --quotatype=spot # Note: adjust this for your cluster
+ #SBATCH --partition=YOUR_PARTITION
+ #SBATCH --quotatype=YOUR_QUOTATYPE
Next, create a directory for your job logs:
mkdir logs
This folder will store the log files generated by your sbatch jobs.
# modules for training / sampling
dllm
├── core                  # Core reusable modules shared across `dllm/pipelines`
│   ├── generation
│   ├── schedulers
│   └── trainers
├── data
├── pipelines             # Application-specific training & inference pipelines
│   ├── bert
│   ├── dream
│   ├── editflow
│   └── llada
│       ├── models        # Model architecture and configs
│       ├── generator.py  # Generation utilities
│       ├── trainer.py    # Core training logic
│       └── eval.py       # Evaluation entry point
├── tools
└── utils
# entry points for training / sampling
examples
├── bert
├── dream
├── editflow
└── llada
    ├── chat.py           # Interactive inference example
    ├── generate.py       # Inference example
    ├── pt.py             # Pretraining example
    ├── README.md         # Documentation (you are here)
    ├── sft.py            # Supervised finetuning example
    └── eval.sh           # Evaluation script
A typical training entry script (for example, examples/llada/sft.py) looks like this:
import transformers
import dllm

# ModelArguments / DataArguments / TrainingArguments are the script's own argument
# dataclasses (assumed to be defined earlier in sft.py; not shown here)
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
# placeholder: load a datasets.DatasetDict with "train" and "test" splits here
dataset = "..."
# ----- Training --------------------------------------------------------------
trainer = dllm.core.trainers.MDLMTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=training_args,
    # pad variable-length examples in each batch; labels are padded with the pad token id
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer,
        return_tensors="pt",
        padding=True,
        label_pad_token_id=tokenizer.pad_token_id,
    ),
)
trainer.train()
You can launch a training job locally with accelerate, or submit it to a Slurm cluster with sbatch.
# Run locally (ZeRO-2 on 8 GPUs with 4-bit quantization and LoRA)
accelerate launch \
--config_file scripts/accelerate_configs/zero2.yaml \
examples/llada/sft.py \
--num_train_epochs 4 \
--load_in_4bit True --lora True
# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs)
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs)
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
See Features for specific training recipes.
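For intuition about what `MDLMTrainer` optimizes, here is a minimal, framework-free sketch of the standard masked-diffusion (MDLM / LLaDA-style) training objective: sample a noise level t, mask each token with probability t, and take the cross-entropy on masked positions only, weighted by 1/t. This is not dllm's implementation: `predict_logits` is a hypothetical stand-in for the model, and the loss is computed for a single unbatched sequence.

```python
import math
import random


def masked_diffusion_loss(token_ids, predict_logits, mask_id, t=None, rng=None):
    """Monte Carlo estimate of a masked-diffusion loss for one clean sequence.

    predict_logits(seq) -> list of per-position logit rows (hypothetical model stub).
    """
    rng = rng or random.Random()
    if t is None:
        # sample a noise level t ~ U(eps, 1); eps keeps the 1/t weight finite
        t = rng.uniform(1e-3, 1.0)
    # forward process: mask each token independently with probability t
    masked = [rng.random() < t for _ in token_ids]
    noisy = [mask_id if m else tok for tok, m in zip(token_ids, masked)]
    logits = predict_logits(noisy)

    total = 0.0
    for pos, (tok, is_masked) in enumerate(zip(token_ids, masked)):
        if not is_masked:
            continue  # unmasked positions contribute no loss
        row = logits[pos]
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_z - row[tok]  # cross-entropy: -log softmax(row)[tok]
    # weight by 1/t and normalize by sequence length
    return total / (t * len(token_ids))
```

With t = 1 every token is masked, so a model that predicts a uniform distribution over V tokens scores exactly log V; a model that is confident and correct drives the loss toward 0.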
Here are some useful tips for training:
- Use a subset of data:
  --dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
- Concatenate datasets:
  --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk"
- Train with LoRA and 4-bit quantization:
  --load_in_4bit True --lora True
- Train with a different distributed training method (choose one of ddp, zero-1, zero-2, zero-3, fsdp):
  --accelerate_config "ddp,zero-{1,2,3},fsdp"
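The `--dataset_args` syntax above packs dataset concatenation and per-split row limits into one string. Below is a hypothetical parser showing one way to read that syntax (it is not dllm's own parsing code): `|` separates datasets, and an optional `[split:n,...]` suffix caps the number of rows per split.

```python
import re

# a dataset name, optionally followed by "[split:n,split:n,...]"
SPEC_RE = re.compile(r"^(?P<name>[^\[\]]+)(?:\[(?P<limits>[^\]]*)\])?$")


def parse_dataset_args(spec):
    """Parse 'a/b[train:10000,test:1000]|c/d' into [(name, {split: max_rows})]."""
    parsed = []
    for part in spec.split("|"):
        part = part.strip()
        match = SPEC_RE.match(part)
        if match is None:
            raise ValueError(f"malformed dataset spec: {part!r}")
        limits = {}
        if match.group("limits"):
            for item in match.group("limits").split(","):
                split, _, n = item.partition(":")
                limits[split.strip()] = int(n)
        parsed.append((match.group("name"), limits))
    return parsed
```

The resulting limits could then be applied with `datasets.Dataset.select(range(n))` and the parts merged with `datasets.concatenate_datasets`; whether dllm does exactly this is an assumption.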
We provide unified generators that abstract away inference details.
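As a rough illustration of the kind of loop such a generator hides, here is a toy, dependency-free sketch of masked-diffusion sampling with greedy low-confidence remasking: the continuation starts fully masked, and each step commits only the most confident predictions while the rest stay masked. It is an assumption-laden sketch, not `LLaDAGenerator`'s implementation: `predict_logits` is a hypothetical model stub, and it omits options such as block-wise decoding and classifier-free guidance.

```python
import math


def masked_diffusion_generate(prompt, gen_length, steps, predict_logits, mask_id):
    """Toy masked-diffusion sampler with greedy low-confidence remasking.

    predict_logits(seq) -> list of per-position logit rows (hypothetical model stub).
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")
    seq = list(prompt) + [mask_id] * gen_length
    # spread the gen_length reveals as evenly as possible across the steps
    base, extra = divmod(gen_length, steps)
    schedule = [base + (1 if i < extra else 0) for i in range(steps)]

    for k in schedule:
        masked = [i for i, tok in enumerate(seq) if tok == mask_id]
        if not masked or k == 0:
            continue
        logits = predict_logits(seq)
        candidates = []
        for i in masked:
            row = logits[i]
            # never predict the mask token itself
            best = max((j for j in range(len(row)) if j != mask_id), key=row.__getitem__)
            peak = max(row)
            z = sum(math.exp(v - peak) for v in row)
            conf = math.exp(row[best] - peak) / z  # softmax probability of `best`
            candidates.append((conf, i, best))
        # commit the k most confident predictions; everything else stays masked
        candidates.sort(reverse=True)
        for _, i, best in candidates[:k]:
            seq[i] = best
    return seq
```

Fewer steps means more tokens are committed per model call, trading quality for speed; with `steps == gen_length` exactly one token is revealed per call.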
A typical inference entry script (for example, examples/llada/generate.py) looks like this:
import dllm
from dllm.pipelines import llada

# script_args: the parsed command-line arguments of the example script (not shown)
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
# for other models, swap in the matching generator and keep everything else unchanged
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
messages = [
    [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
    [{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)
outputs = generator.generate(inputs, return_dict_in_generate=True)
# decode_trim: helper from the example script (not shown) that decodes each
# generated sequence with its prompt removed
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
You can also try the interactive chat script (for example, examples/llada/chat.py) for visualized multi-turn dialogue.
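`decode_trim` is not defined in the inference snippet above. As an assumption about what such a helper does, here is a hypothetical, tokenizer-agnostic stand-in (`trim_and_decode`) that removes each prompt prefix, optionally stops at an end-of-sequence id, and decodes the rest. It assumes every output row begins with its own unpadded prompt, which may not hold if a generator pads batched prompts.

```python
def trim_and_decode(decode, sequences, prompts, eos_id=None):
    """Hypothetical helper: strip each prompt prefix, stop at eos_id, then decode."""
    texts = []
    for seq, prompt in zip(sequences, prompts):
        # assumes seq starts with its own unpadded prompt tokens
        completion = list(seq[len(prompt):])
        if eos_id is not None and eos_id in completion:
            completion = completion[: completion.index(eos_id)]
        texts.append(decode(completion))
    return texts
```

With a Hugging Face tokenizer this might be called as `trim_and_decode(lambda ids: tokenizer.decode(ids, skip_special_tokens=True), outputs.sequences.tolist(), inputs, eos_id=tokenizer.eos_token_id)`.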
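Before running evaluation, complete the (optional) evaluation setup: initialize and install the `lm-evaluation-harness` submodule as described in the installation steps.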
For example, to evaluate LLaDA-8B-Instruct on MMLU_Pro, run:
accelerate launch --num_processes 4 \
dllm/pipelines/llada/eval.py \
--tasks "mmlu_pro" \
--model "llada" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
We also provide scripts to automatically evaluate LLaDA, Dream, and BERT-Chat on all benchmarks.
For example, you can launch examples/llada/eval.sh directly using the following commands:
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False
@misc{dllm,
author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
title = {dLLM: Simple Diffusion Language Modeling},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}



