[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Sep 17 13:55:26 PDT 2025
================
@@ -631,6 +633,151 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check if the scales input is used in other scaled mfma's while they exist.
+/// If they're unused then pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ScaledMFMAOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // If this use of a scale has a non zero opsel, packing has already been
+ // done.
+ auto checkIfUnpackable = [&](OpOperand &op) {
+ if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
+ switch (op.getOperandNumber()) {
+ case 3:
+ return smfma.getScalesIdxA() != 0;
+ case 4:
+ return smfma.getScalesIdxB() != 0;
+ default:
+ break;
+ }
+ }
+ return true;
+ };
+
+ auto setOpsel = [&](unsigned idx, int64_t val) {
+ switch (idx) {
+ case 3:
+ op.setScalesIdxA(val);
+ break;
+ case 4:
+ op.setScalesIdxB(val);
+ break;
+ default:
+ break;
+ }
+ };
+
+ // Obtain flat index from offsets and shape.
+ auto getIdxFromExtract = [](vector::ExtractOp op) {
+ ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
+ int64_t cumul = 1;
+ int64_t idx = 0;
+ for (auto [offset, size] :
+ reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+ idx += offset * cumul;
+ cumul *= size;
+ }
+ return idx;
+ };
+
+ // Obtain offsets for new shape from flat index.
+ auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+ SmallVector<int64_t> res;
+ ShapedType shapedty = static_cast<ShapedType>(ty);
+ int64_t numElements = shapedty.getNumElements();
+ for (unsigned size : shapedty.getShape()) {
+ numElements /= size;
+ res.push_back(idx / numElements);
+ idx -= (idx / numElements) * size;
+ }
+ return res;
+ };
+
+ // For every scale operand of this ScaledMFMAOp, if the scale follows the
+ // following pattern:
+ //
+ // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
+ // vector<?x?x?xf8E8M0FNU> %scale = vector.insert %unit, ... : f8E8M0FNU
+ // into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
+ //
+ // rewrite to:
+ //
+ // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
+ // vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
+ // vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+ // amdgpu.scaled_mfma(%scale[0-3] * ...
+ //
+ // This creates duplicate shape_casts for every use but these will be
+ // removed in CSE.
+ for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+ auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+ if (!insertOp) {
+ return rewriter.notifyMatchFailure(op,
+ "defining op not a vector.insert");
+ }
+ if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+ return rewriter.notifyMatchFailure(op,
+ "some scaled mfma's already packed");
+ }
+
+ auto extractOp =
+ insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+ if (!extractOp) {
+ return rewriter.notifyMatchFailure(op,
+ "defining op not a vector.extract");
+ }
+
+ Value scaleSrc = extractOp.getOperand(0);
+ auto stype = dyn_cast<VectorType>(scaleSrc.getType());
+ if (!stype) {
+ return rewriter.notifyMatchFailure(op, "not a shaped type");
+ }
+ // We do not handle dynamic dims yet, assume that the input is padded to
+ // a static shape now.
+ if (!stype.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "dynamic dims not yet supported");
+ }
+
+ int64_t numElements = stype.getNumElements();
+ if (numElements <= 4 || !(numElements % 4)) {
+ return rewriter.notifyMatchFailure(
+ op, "no packing if # of scales less than or indivisible by four");
+ }
+
+ Type newSrcType = VectorType::get(
+ SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+ Value newScaleSrc =
+ rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+ int64_t idx = getIdxFromExtract(extractOp);
+ SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
+ auto scaleTy = VectorType::get({4}, stype.getElementType());
+ Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+ SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
----------------
kuhar wrote:
Also here, don't use SmallVector if all you need is a few elements to be passed as ArrayRef
https://github.com/llvm/llvm-project/pull/155951
More information about the Mlir-commits
mailing list