dLLM

Simple Diffusion Language Modeling

[Figure: dLLM logo]

Overview

dLLM is a library that unifies the training and evaluation of diffusion language models, bringing transparency and reproducibility to the entire development pipeline:

  • dLLM provides scalable training pipelines (inspired by transformers Trainer), with support for LoRA, DeepSpeed, FSDP, and more.

  • dLLM provides unified evaluation pipelines (inspired by lm-evaluation-harness) that abstract away inference details and make customization simple.

  • Built on these components, dLLM provides minimal pretraining / finetuning / evaluation recipes for open-weight models (e.g., LLaDA and Dream), and implementations of training algorithms (e.g., Edit Flows).

News

[2025/11] We released a collection of BERTs finetuned for instruction-following: ModernBERT-{large,base}-chat-v0. This proof-of-concept shows that BERT's internal knowledge can be leveraged for generative tasks via masked instruction tuning. See the blog post BERT Chat Report for detailed recipes, experimental results, and lessons learned; see examples/bert for training / inference / evaluation instructions.

Features

  • examples/llada: Pretraining, finetuning, and evaluating LLaDA / LLaDA-MoE.

  • examples/dream: Pretraining, finetuning, and evaluating Dream.

  • examples/bert: Finetuning any BERT into a lightweight chatbot.

    [BERT Chat demo: chat with ModernBERT-large-chat-v0. See Inference for details.]

  • examples/editflow: Educational reference for training EditFlow models, demonstrating how to extend existing dLLMs (e.g., LLaDA, Dream, BERT Chat) with edit operations (insertion, deletion, and substitution), and how to pretrain or finetune EditFlow models from scratch on public data.

    [EditFlow demo: EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough, then removed) during generation.]

  • More coming soon.

Setup

Installation

# create and activate conda environment
conda create -n dllm python=3.10 -y
conda activate dllm

# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
conda install cuda=12.4 -c nvidia
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
    --index-url https://download.pytorch.org/whl/cu124

# install dllm package
pip install -e .
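
To verify the setup, a quick Python sanity check (assuming the steps above succeeded):

import torch

import dllm  # should import cleanly after `pip install -e .`

print(torch.__version__)          # expect 2.6.0
print(torch.cuda.is_available())  # expect True on a CUDA machine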

(optional) Evaluation setup

# initialize `lm-evaluation-harness` submodule
git submodule update --init --recursive

# install submodule in editable mode with IFEval & Math dependencies
pip install -e "lm-evaluation-harness[ifeval,math]"

(optional) Slurm setup

For Slurm users, update scripts/train.slurm.sh for your cluster:

- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
- #SBATCH --quotatype=spot        # Note: adjust this for your cluster
+ #SBATCH --partition=YOUR_PARTITION
+ #SBATCH --quotatype=YOUR_QUOTATYPE

Next, create a directory for your job logs:

mkdir logs

This folder will store the log files generated by your sbatch jobs.

Files overview

# modules for training / sampling
dllm
├── core                   # Core reusable modules shared across `dllm/pipelines`
│   ├── generation
│   ├── schedulers
│   └── trainers
├── data
├── pipelines              # Application-specific training & inference pipelines
│   ├── bert
│   ├── dream
│   ├── editflow
│   └── llada
│       ├── models         # Model architectures and configs
│       ├── generator.py   # Generation utilities
│       ├── trainer.py     # Core training logic
│       └── eval.py        # Evaluation entry point
├── tools
└── utils

# entry points for training / sampling
examples
├── bert
├── dream
├── editflow
└── llada
    ├── chat.py            # Interactive inference example
    ├── generate.py        # Inference example
    ├── pt.py              # Pretraining example
    ├── README.md          # Documentation (you are here)
    ├── sft.py             # Supervised finetuning example
    └── eval.sh            # Evaluation script

Training

A typical training entry script (for example, examples/llada/sft.py) looks like this:

import transformers

import dllm

# `parser` is a command-line parser over the model / data / training
# argument dataclasses (see the sketch after this block)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
dataset = "..."  # load a dataset exposing "train" and "test" splits here

# ----- Training --------------------------------------------------------------
trainer = dllm.core.trainers.MDLMTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=training_args,
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer,
        return_tensors="pt",
        padding=True,
        label_pad_token_id=tokenizer.pad_token_id, 
    ),
)
trainer.train()
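
The `parser` above is typically a `transformers.HfArgumentParser` over the three argument dataclasses. A minimal sketch, using hypothetical stand-in dataclasses rather than dLLM's actual ones:

from dataclasses import dataclass, field

import transformers

# hypothetical dataclasses for illustration; dLLM defines its own equivalents
@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="GSAI-ML/LLaDA-8B-Base")

@dataclass
class DataArguments:
    dataset_args: str = field(default="allenai/tulu-3-sft-mixture")

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, transformers.TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()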

You can launch a training job locally with accelerate, or submit it to a Slurm cluster using sbatch.

# Run locally (ZeRO-2 on 8 GPUs with 4bit quantization and LoRA)
accelerate launch \
    --config_file scripts/accelerate_configs/zero2.yaml \
    examples/llada/sft.py \
    --num_train_epochs 4 \
    --load_in_4bit True --lora True

# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs)
sbatch --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --num_train_epochs 4

# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs)
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --num_train_epochs 4

See Features for specific training recipes.

Here are some useful tips for training:

  1. Use a subset of data: --dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
  2. Concatenate datasets: --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" (tips 1-2 are sketched below)
  3. Train with LoRA and 4bit quantization: --load_in_4bit True --lora True
  4. Train with different distributed training methods: --accelerate_config "ddp,zero-{1,2,3},fsdp"
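
Tips 1 and 2 correspond to plain Hugging Face datasets operations. A rough sketch of what they select (not dLLM's exact loading code):

from datasets import concatenate_datasets, load_dataset

# tip 1: keep only the first 10k training examples
train_subset = load_dataset("allenai/tulu-3-sft-mixture", split="train").select(range(10000))

# tip 2: concatenate two chat datasets on their shared "messages" column
# ("all" is one of smoltalk's config names)
mixed = concatenate_datasets([
    load_dataset("allenai/tulu-3-sft-mixture", split="train").select_columns(["messages"]),
    load_dataset("HuggingFaceTB/smoltalk", "all", split="train").select_columns(["messages"]),
])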

Inference

We provide unified generators that abstract away inference details. A typical inference entry script (for example, examples/llada/generate.py) looks like this:

import dllm
from dllm import llada

# `script_args` is parsed from the command line, as in the training example
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
# for other models, swap in the matching generator and keep the rest unchanged
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)

messages = [
    [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
    [{"role": "user", "content": "Please write an educational python function."}],
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)

outputs = generator.generate(inputs, return_dict_in_generate=True)
# decode generated ids to text, trimming each prompt (see the sketch below)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
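
Here `decode_trim` strips each prompt's tokens from its generated sequence before decoding. A minimal stand-in (the real helper ships with dLLM and may differ in detail):

def decode_trim(tokenizer, sequences, inputs):
    """Drop each prompt prefix from its sequence, then decode to text."""
    return [
        tokenizer.decode(seq[len(prompt):], skip_special_tokens=True)
        for seq, prompt in zip(sequences, inputs)
    ]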

You can also try the interactive chat script (for example, examples/llada/chat.py) for a visualized multi-turn dialogue:

[Demo: interactive multi-turn chat]

Evaluation

Read (optional) Evaluation setup before running evaluation.

For example, to evaluate LLaDA-8B-Instruct on MMLU_Pro, run:

accelerate launch --num_processes 4 \
    dllm/pipelines/llada/eval.py \
    --tasks "mmlu_pro" \
    --model "llada" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"

We also provide scripts to automatically evaluate LLaDA, Dream, and BERT-Chat on all benchmarks. For example, you can launch examples/llada/eval.sh directly using the following commands:

bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False

Citation

@misc{dllm,
    author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
    title = {dLLM: Simple Diffusion Language Modeling},
    year = {2025},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}
