Conversation

@jcaip (Contributor) commented Dec 8, 2025

This PR hooks up the static quant workflow added in #3442 to the prototype smoothquant API.

You can use the new flow as follows:

from torchao.quantization import PerRow, quantize_
from torchao.quantization.quant_api import (
    Int8StaticActivationInt8WeightConfig,
)
from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
    SmoothQuantStep,
)

config = SmoothQuantConfig(
    base_config=Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
    step=SmoothQuantStep.PREPARE,
    alpha=0.5,
)

quantize_(model, config)

# Perform calibration with test data
model(*x)

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)

# The model is now statically quantized, using the scales observed during calibration.
model(*x)
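
For reference, the smoothing factor that the PREPARE-step observer derives from calibration follows the standard SmoothQuant formula; a minimal sketch of that computation, illustrative rather than this PR's observer code:

import torch

def smoothquant_factor(
    act_absmax: torch.Tensor,  # per-channel max |activation| from calibration
    w_absmax: torch.Tensor,    # per-channel max |weight|
    alpha: float = 0.5,
) -> torch.Tensor:
    # SmoothQuant (Xiao et al.): s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).
    # alpha balances quantization difficulty between activations and weights;
    # alpha=0.5 (as in the example above) splits it evenly.
    return act_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)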

@pytorch-bot (bot) commented Dec 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3468

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending

As of commit 0c23589 with merge base f99105a:

NEW FAILURE - one job has failed; see the HUD link above for details.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the `CLA Signed` label Dec 8, 2025
@jcaip jcaip added the `topic: improvement` label Dec 8, 2025
@jcaip jcaip changed the title [wip] enable smoothquant for int8 static tensor enable smoothquant for int8 static tensor Dec 8, 2025
@jcaip jcaip marked this pull request as ready for review December 8, 2025 22:24
@jcaip jcaip requested a review from jerryzh168 December 8, 2025 22:29
@jcaip (Contributor, Author) commented Dec 8, 2025

cc @Xia-Weiwen and @cyxlily fyi

Review thread on this diff hunk:

qw = quant_mod.weight

# Add smoothing factor metadata
qw = to_weight_tensor_with_linear_activation_scale_metadata(
A reviewer (Contributor) commented:
we should not be using this, please check AWQ for how this should be implemented in the new stack:

assert isinstance(qw, SupportsActivationPreScaling), (
    "weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the
# reciprocal of the `equalization_scale`
qw.act_pre_scale = 1.0 / equalization_scale
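
For context, the runtime effect of storing the reciprocal is that smoothing becomes a single multiply on the activation; a minimal sketch, illustrative rather than the actual kernel, where `qw` is assumed to carry an `act_pre_scale` attribute:

import torch
import torch.nn.functional as F

def smoothed_linear(act, qw, bias=None):
    # `act * act_pre_scale` is equivalent to `act / equalization_scale`,
    # but avoids a divide on the inference hot path.
    return F.linear(act * qw.act_pre_scale, qw, bias)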

"""

scale: torch.Tensor
scale: torch.Tensor = None
A reviewer (Contributor) commented:
nit: Optional[torch.Tensor]
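
A minimal sketch of the suggested annotation (shown standalone for illustration; in the diff the field lives on the config class):

from typing import Optional

import torch

# before: annotation says torch.Tensor but the default is None
scale: torch.Tensor = None
# after (suggested): the annotation admits the None default
scale: Optional[torch.Tensor] = None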

Review thread on this diff hunk:

[
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(version=2),
# TODO: not sure if we should allow not passing scales as part of static config?
A reviewer (Contributor) replied:
Yeah, I think it's fine.

Side note: we may need a separate API/flow for plain static quant without SmoothQuant.
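
For illustration, a plain static-quant flow without SmoothQuant might look like the sketch below. The `scale` field name comes from the diff above, but passing a precomputed activation scale directly like this is an assumption about the config's API, not something confirmed by this PR:

import torch
from torchao.quantization import PerRow, quantize_
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

# Assumption: an activation scale computed offline from calibration data.
act_scale = torch.tensor(0.02)

config = Int8StaticActivationInt8WeightConfig(
    scale=act_scale,  # assumed: static activation scale supplied directly
    granularity=PerRow(),
)
quantize_(model, config)  # `model` as in the usage example above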

@jcaip jcaip changed the base branch from jcaip/static-quant-rebased to main December 9, 2025 04:34