[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)
llvmlistbot at llvm.org
Wed Sep 17 12:51:02 PDT 2025
https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/155951
>From 2b6d91708127fb3da9f648c778037981657c3ad1 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 28 Aug 2025 17:01:09 -0700
Subject: [PATCH 1/2] Add packing of scales for ScaledMFMAOp
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 1 +
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 142 ++++++++++++++++++
mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt | 1 +
mlir/test/Dialect/AMDGPU/canonicalize.mlir | 25 +++
4 files changed, 169 insertions(+)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
+ let hasCanonicalizer = 1;
}
#endif // AMDGPU
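For reference (not part of the patch): setting hasCanonicalizer = 1 makes ODS
declare a canonicalization hook on the generated ScaledMFMAOp class; the
AMDGPUDialect.cpp change below supplies its definition. A minimal sketch of the
generated declaration, assuming the usual ODS output:

  // Sketch only: the hook ODS declares once hasCanonicalizer is set.
  // The definition is added in AMDGPUDialect.cpp below.
  class ScaledMFMAOp /* : public mlir::Op<...> in the generated header */ {
  public:
    static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
                                            ::mlir::MLIRContext *context);
  };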
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <cstdint>
#include <limits>
#include <optional>
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check whether the scales input is used by other scaled MFMAs, if any
+/// exist. If those uses are still unpacked, pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ScaledMFMAOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // If this use of a scale has a non-zero opsel, packing has already been
+ // done.
+ auto checkIfUnpackable = [&](OpOperand &op) {
+ if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
+ switch (op.getOperandNumber()) {
+ case 3:
+ return smfma.getScalesIdxA() != 0;
+ break;
+ case 4:
+ return smfma.getScalesIdxB() != 0;
+ break;
+ default:
+ return true;
+ break;
+ }
+ }
+ };
+
+ auto setOpsel = [&](unsigned idx, int64_t val) {
+ switch (idx) {
+ case 3:
+ return op.setScalesIdxA(val);
+ break;
+ case 4:
+ return op.setScalesIdxB(val);
+ break;
+ default:
+ break;
+ }
+ };
+
+ // Obtain flat index from offsets and shape.
+ auto getIdxFromExtract = [](vector::ExtractOp op) {
+ ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
+ int cumul = 1;
+ int idx = 0;
+ for (auto [offset, size] :
+ reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+ idx += offset * cumul;
+ cumul *= size;
+ }
+ return idx;
+ };
+
+ // Obtain offsets for new shape from flat index.
+ auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+ SmallVector<int64_t> res;
+ ShapedType shapedty = static_cast<ShapedType>(ty);
+ int64_t numElements = shapedty.getNumElements();
+ for (auto size : shapedty.getShape()) {
+ numElements /= size;
+ res.push_back(idx / numElements);
+ idx -= (idx / numElements) * numElements;
+ }
+ return res;
+ };
+
+ // For every scale operand of this ScaledMFMAOp, if the scale follows the
+ // following pattern:
+ //
+ // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
+ // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+ // amdgpu.scaled_mfma(%scale[0] * ...
+ //
+ // rewrite to:
+ //
+ // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+ // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+ // amdgpu.scaled_mfma(%scale[0-3] * ...
+ //
+ // This creates duplicate shape_casts for every use but these will be removed in CSE.
+ for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+ auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+ if (!insertOp) {
+ return failure();
+ }
+ if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+ return failure();
+ }
+
+ auto extractOp =
+ insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+ if (!extractOp) {
+ return failure();
+ }
+
+ Value scaleSrc = extractOp.getOperand(0);
+ auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+ if (!stype) {
+ return failure();
+ }
+ // We do not handle dynamic dims yet, assume that the input is padded to
+ // a static shape now.
+ if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
+ [&](int64_t i) { return stype.isDynamicDim(i); })) {
+ return failure();
+ }
+
+ int64_t numElements = stype.getNumElements();
+ if (numElements <= 4) {
+ return failure();
+ }
+
+ Type newSrcType = VectorType::get(
+ SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+ Value newScaleSrc =
+ rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+ int64_t idx = getIdxFromExtract(extractOp);
+ SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
+ auto scaleTy = VectorType::get({4}, stype.getElementType());
+ Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+ SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+ Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+ op.setOperand(opIdx, scale);
+ setOpsel(opIdx, offsets[1]);
+ }
+ return success();
+ }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<PackScales>(context);
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
+ MLIRVectorDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRMemRefUtils
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+ %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
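As a standalone illustration (not part of the patch) of the index arithmetic
behind the CHECK lines above, here is a small C++ sketch with hypothetical
helper names. It reproduces how the first scale, extracted at [0, 0, 3, 0],
ends up as row 0 of the reshaped scales with opsel 3:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Flatten an extract position into a row-major linear index, as the
  // pattern's getIdxFromExtract lambda does.
  static int64_t flattenIndex(const std::vector<int64_t> &pos,
                              const std::vector<int64_t> &shape) {
    int64_t cumul = 1, idx = 0;
    for (int64_t i = static_cast<int64_t>(pos.size()) - 1; i >= 0; --i) {
      idx += pos[i] * cumul;
      cumul *= shape[i];
    }
    return idx;
  }

  // Standard row-major delinearization of a flat index into offsets for the
  // reshaped <N x 4> type; for the 4x4 case in the test this matches what the
  // pattern's getOffsetsFromIdx lambda computes.
  static std::vector<int64_t> delinearize(int64_t idx, int64_t numElements,
                                          const std::vector<int64_t> &shape) {
    std::vector<int64_t> res;
    for (int64_t size : shape) {
      numElements /= size;
      res.push_back(idx / numElements);
      idx -= (idx / numElements) * numElements;
    }
    return res;
  }

  int main() {
    // %scaleA = vector.extract %scalesA[0, 0, 3, 0]
    //     : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
    int64_t flat = flattenIndex({0, 0, 3, 0}, {2, 1, 8, 1}); // 3
    // The 16 scales are reshaped to vector<4x4xf8E8M0FNU>.
    std::vector<int64_t> offsets = delinearize(flat, 16, {4, 4}); // {0, 3}
    assert(flat == 3 && offsets[0] == 0 && offsets[1] == 3);
    // Row offsets[0] = 0 becomes the packed vector<4xf8E8M0FNU>, and
    // offsets[1] = 3 becomes scalesIdxA, matching "%[[SCALE_1]][3]" above.
    return 0;
  }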
>From 3873edac6f205ac98808103bfdb1251eedfadf99 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 17 Sep 2025 14:36:47 -0500
Subject: [PATCH 2/2] PR review round 0
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 26 ++++++++++----------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4107ec53a0988..2e3f95651902e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -653,26 +653,24 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
switch (op.getOperandNumber()) {
case 3:
return smfma.getScalesIdxA() != 0;
- break;
case 4:
return smfma.getScalesIdxB() != 0;
- break;
default:
- return true;
break;
}
}
+ return true;
};
auto setOpsel = [&](unsigned idx, int64_t val) {
switch (idx) {
case 3:
- return op.setScalesIdxA(val);
+ op.setScalesIdxA(val);
break;
case 4:
- return op.setScalesIdxB(val);
+ op.setScalesIdxB(val);
break;
- default:
+ default:
break;
}
};
@@ -695,7 +693,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
SmallVector<int64_t> res;
ShapedType shapedty = static_cast<ShapedType>(ty);
int64_t numElements = shapedty.getNumElements();
- for (auto size : shapedty.getShape()) {
+ for (unsigned size : shapedty.getShape()) {
numElements /= size;
res.push_back(idx / numElements);
idx -= (idx / numElements) * numElements;
@@ -706,17 +704,19 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
// For every scale operand of this ScaledMFMAOp, if the scale follows the
// following pattern:
//
- // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
- // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
- // amdgpu.scaled_mfma(%scale[0] * ...
+ // %unit = vector.extract %ScaleSrc[offsets]
+ //     : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
+ // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+ // amdgpu.scaled_mfma(%scale[0] * ...
//
// rewrite to:
//
- // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
- // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+ // %reshaped = vector.shape_cast %ScaleSrc
+ //     : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+ // %scale = vector.extract %reshaped[?]
+ //     : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
// amdgpu.scaled_mfma(%scale[0-3] * ...
//
- // This creates duplicate shape_casts for every use but these will be removed in CSE.
+ // This creates duplicate shape_casts for every use but these will be
+ // removed in CSE.
for (auto opIdx : SmallVector<int64_t>({3, 4})) {
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
if (!insertOp) {