[Mlir-commits] [mlir] 9628061 - [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (#155951)

llvmlistbot at llvm.org
Thu Sep 18 12:25:19 PDT 2025


Author: Muzammil
Date: 2025-09-18T19:25:14Z
New Revision: 9628061e055c9f695ff80f9a74e4f6e524b34993

URL: https://github.com/llvm/llvm-project/commit/9628061e055c9f695ff80f9a74e4f6e524b34993
DIFF: https://github.com/llvm/llvm-project/commit/9628061e055c9f695ff80f9a74e4f6e524b34993.diff

LOG: [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (#155951)

The ScaledMFMAOp accepts each scale operand as a vector of 4 bytes
(`vector<4xf8E8M0FNU>`) that can be stored in a single register, with the
particular scale selected via the `OpSel` attribute. Currently we only
populate one byte of this 4-byte vector, so packing four scales takes four
registers instead of one, wasting 3 registers.

This is fixed by identifying when single-byte scale extractions are
performed and rewriting them into extractions of 4-byte vectors.

Example:
```
  %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
  %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
  amdgpu.scaled_mfma(%scale[0] * ...
```
to
```
  %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?xf8E8M0FNU>
  %scale = vector.extract_strided_slice %reshaped {offsets = [?], sizes = [4], strides = [1]} : vector<?xf8E8M0FNU> to vector<4xf8E8M0FNU>
  amdgpu.scaled_mfma(%scale[0-3] * ...
```
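
As a concrete illustration (mirroring the `scaled_mfma` test added in this
patch; the `%cst`, `%opA`, `%opB`, and `%acc` names are illustrative), a scale
taken from position `[0, 0, 3, 0]` of a `vector<2x1x8x1xf8E8M0FNU>` source
linearizes to index 3 and lands in the 4-byte group at offset 0 with opsel 3,
while `[0, 0, 6, 0]` linearizes to index 6, giving offset 4 and opsel 2:
```
// Before:
%scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sA = vector.insert %scaleA, %cst [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %acc ...

// After canonicalization (duplicate shape_casts are cleaned up by CSE):
%flatA = vector.shape_cast %scalesA : vector<2x1x8x1xf8E8M0FNU> to vector<16xf8E8M0FNU>
%packedA = vector.extract_strided_slice %flatA {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
%flatB = vector.shape_cast %scalesB : vector<2x1x8x1xf8E8M0FNU> to vector<16xf8E8M0FNU>
%packedB = vector.extract_strided_slice %flatB {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
%res = amdgpu.scaled_mfma(%packedA[3] * %opA) * (%packedB[2] * %opB) + %acc ...
```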

---------

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
    mlir/test/Dialect/AMDGPU/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
+  let hasCanonicalizer = 1;
 }
 #endif // AMDGPU

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cfaea25fe5918..d5c71905f7b4a 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -26,8 +27,11 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <algorithm>
+#include <cstdint>
 #include <limits>
 #include <optional>
 
@@ -631,6 +635,139 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pack the scale operands of a scaled_mfma: when a scale is produced by
+/// inserting a single extracted byte into a 4-byte vector, rewrite it to
+/// extract a full 4-byte group from the source and select the byte via opsel.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto setOpsel = [&op](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale is produced by
+    // the extraction of a single scale from some vector, then attempt to
+    // extract 4 values from that vector instead.
+    //
+    // Example: (f8 here means f8E8M0FNU)
+    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
+    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
+    // %scale = vector.extract_strided_slice %reshaped
+    //     {offsets = [?], sizes = [4], strides = [1]} : vector<?xf8> to vector<4xf8>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates a duplicate shape_cast for every use, but the duplicates
+    // will be removed by CSE.
+    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.insert");
+      }
+      // If the value being inserted is not a single scalar, the scales have
+      // already been packed.
+      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
+        return rewriter.notifyMatchFailure(
+            op, "scaled mfma operand already packed");
+      }
+
+      auto extractOp =
+          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.extract");
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
+      if (!scaleSrcType) {
+        return rewriter.notifyMatchFailure(op, "not a vector type");
+      }
+
+      // We do not handle dynamic dimensions yet; assume the input has been
+      // padded to a static shape.
+      if (!scaleSrcType.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(op,
+                                           "dynamic dims not yet supported");
+      }
+
+      int64_t numElements = scaleSrcType.getNumElements();
+      if (numElements <= 4) {
+        return rewriter.notifyMatchFailure(
+            op, "no packing if there are four or fewer scales");
+      }
+
+      // Compute the linearized index of the extracted scale from the static
+      // position of the extract op and the shape of the scale source.
+      auto extractedPos = llvm::to_vector_of<int64_t>(
+          llvm::reverse(extractOp.getStaticPosition()));
+      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
+      int64_t scaleSrcRank = scaleSrcType.getRank();
+      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
+      for (int64_t i = 1; i < scaleSrcRank; ++i) {
+        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
+      }
+      int64_t idx = linearize(extractedPos, extractSizes);
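+      // For example, extracting the scale at position [0, 0, 3, 0] from a
+      // vector<2x1x8x1xf8E8M0FNU> source (as in the tests below) linearizes
+      // to idx = 3.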
+
+      // All n scales (where n is the total number of scales) must now be
+      // extracted in chunks of 4 elements. This is done by dividing the
+      // original vector of scales into groups of 4 elements at offsets
+      // 0, 4, ..., n - 4. Every extraction of a scale at a particular index
+      // is replaced with an extraction of the whole 4-element group that
+      // the index belongs to, with opsel selecting the scale within the
+      // group.
+      //
+      // If the number of scales is not divisible by 4, the last (partial)
+      // group is extracted as 4 elements starting at offset n - 4, and
+      // opsel is adjusted accordingly.
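+      // For example, with 161 scales (a 7x23 source, as in the tests below)
+      // and idx = 160, the group would run past the end, so offset becomes
+      // 157 and opsel becomes 3.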
+      int64_t offset = idx - (idx % 4);
+      int64_t opsel = idx - offset;
+      int64_t size = 4l;
+      // Accommodate remaining elements in the case of non-4-divisible vectors.
+      if (numElements - offset < size) {
+        opsel = size - (numElements - idx);
+        offset = numElements - 4l;
+      }
+      Type scaleSrcElemType = scaleSrcType.getElementType();
+      auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
+                                        scaleSrcElemType);
+      Value newScaleSrc =
+          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
+      auto extract = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
+          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
+      rewriter.modifyOpInPlace(op, [&] {
+        op->setOperand(opIdx, extract);
+        setOpsel(opIdx, opsel);
+      });
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRMemRefUtils

diff  --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..52d3275dab43b 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,88 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
     : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_less_than_4
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
+func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_ugly_shapes
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+
+  // idx = 160 => opsel = 3 (last idx of last 4 bytes)
+  %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 159 => opsel = 3
+  %scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 158 => opsel = 2
+  %scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 157 => opsel = 1
+  %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+
+  %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  
+  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}


        

