[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 17 15:48:03 PDT 2025


https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/155951

>From 2b6d91708127fb3da9f648c778037981657c3ad1 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 28 Aug 2025 17:01:09 -0700
Subject: [PATCH 1/4] Add packing of scales for ScaledMFMAOp

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |   1 +
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 142 ++++++++++++++++++
 mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt     |   1 +
 mlir/test/Dialect/AMDGPU/canonicalize.mlir    |  25 +++
 4 files changed, 169 insertions(+)
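
For readers skimming the thread, a minimal before/after sketch of what this
canonicalization does; the SSA names are illustrative and the exact IR is in
the test added below:

  // Before: each scaled_mfma use builds its own 4-wide scale vector from a
  // single extracted element and always reads lane 0 (opsel = 0).
  %s   = vector.extract %scales[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %vec = vector.insert %s, %cst [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %r   = amdgpu.scaled_mfma(%vec[0] * %a) * ...

  // After: the scale source is reshaped so that four neighboring scales share
  // one 4-wide vector and opsel selects the lane. [0, 0, 3, 0] linearizes
  // (row-major) to flat index ((0*1 + 0)*8 + 3)*1 + 0 = 3, i.e. row 0,
  // lane 3 of the 4x4 reshape.
  %reshaped = vector.shape_cast %scales : vector<2x1x8x1xf8E8M0FNU> to vector<4x4xf8E8M0FNU>
  %vec      = vector.extract %reshaped[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
  %r        = amdgpu.scaled_mfma(%vec[3] * %a) * ...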

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
+  let hasCanonicalizer = 1;
 }
 #endif // AMDGPU
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <cstdint>
 #include <limits>
 #include <optional>
 
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check how the scales input is used by other scaled MFMAs; only pack the
+/// scales if every use is a scaled MFMA that has not been packed yet.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non-zero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &op) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
+        switch (op.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+          break;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+          break;
+        default:
+          return true;
+          break;
+        }
+      }
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        return op.setScalesIdxA(val);
+        break;
+      case 4:
+        return op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain flat index from offsets and shape.
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
+      int cumul = 1;
+      int idx = 0;
+      for (auto [offset, size] :
+           reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Obtain offsets for new shape from flat index.
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      ShapedType shapedty = static_cast<ShapedType>(ty);
+      int64_t numElements = shapedty.getNumElements();
+      for (auto size : shapedty.getShape()) {
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx -= (idx / numElements) * size;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale matches the
+    // following pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates duplicate shape_casts for every use but these will be removed in CSE.
+    for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return failure();
+      }
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+        return failure();
+      }
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return failure();
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      if (!stype) {
+        return failure();
+      }
+      // We do not handle dynamic dims yet; assume that the input has been
+      // padded to a static shape.
+      if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
+                       [&](int64_t i) { return stype.isDynamicDim(i); })) {
+        return failure();
+      }
+
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4) {
+        return failure();
+      }
+
+      Type newSrcType = VectorType::get(
+          SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      op.setOperand(opIdx, scale);
+      setOpsel(opIdx, offsets[1]);
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRMemRefUtils
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
     : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}

>From 3873edac6f205ac98808103bfdb1251eedfadf99 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 17 Sep 2025 14:36:47 -0500
Subject: [PATCH 2/4] PR review round 0

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 26 ++++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4107ec53a0988..2e3f95651902e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -653,26 +653,24 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
         switch (op.getOperandNumber()) {
         case 3:
           return smfma.getScalesIdxA() != 0;
-          break;
         case 4:
           return smfma.getScalesIdxB() != 0;
-          break;
         default:
-          return true;
           break;
         }
       }
+      return true;
     };
 
     auto setOpsel = [&](unsigned idx, int64_t val) {
       switch (idx) {
       case 3:
-        return op.setScalesIdxA(val);
+        op.setScalesIdxA(val);
         break;
       case 4:
-        return op.setScalesIdxB(val);
+        op.setScalesIdxB(val);
         break;
-      default:
+      default: 
         break;
       }
     };
@@ -695,7 +693,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       SmallVector<int64_t> res;
       ShapedType shapedty = static_cast<ShapedType>(ty);
       int64_t numElements = shapedty.getNumElements();
-      for (auto size : shapedty.getShape()) {
+      for (unsigned size : shapedty.getShape()) {
         numElements /= size;
         res.push_back(idx / numElements);
         idx -= (idx / numElements) * size;
@@ -706,17 +704,19 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
     // For every scale operand of this ScaledMFMAOp, if the scale matches the
     // following pattern:
     //
-    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
-    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
-    // amdgpu.scaled_mfma(%scale[0] * ...
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
+    // vector<?x?x?xf8E8M0FNU> %scale = vector.insert %unit, ... : f8E8M0FNU
+    // into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
     //
     // rewrite to:
     //
-    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
-    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
+    // vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
+    // vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
     // amdgpu.scaled_mfma(%scale[0-3] * ...
     //
-    // This creates duplicate shape_casts for every use but these will be removed in CSE.
+    // This creates duplicate shape_casts for every use but these will be
+    // removed in CSE.
     for (auto opIdx : SmallVector<int64_t>({3, 4})) {
       auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
       if (!insertOp) {

>From 2404d99880984e087cbe9d62faf1d3a4205effd8 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 17 Sep 2025 15:31:56 -0500
Subject: [PATCH 3/4] PR review round 1

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 33 +++++++++++---------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2e3f95651902e..1cc800ec92090 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -670,7 +670,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       case 4:
         op.setScalesIdxB(val);
         break;
-      default: 
+      default:
         break;
       }
     };
@@ -678,8 +678,8 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
     // Obtain flat index from offsets and shape.
     auto getIdxFromExtract = [](vector::ExtractOp op) {
       ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
-      int cumul = 1;
-      int idx = 0;
+      int64_t cumul = 1;
+      int64_t idx = 0;
       for (auto [offset, size] :
            reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
         idx += offset * cumul;
@@ -720,33 +720,37 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
     for (auto opIdx : SmallVector<int64_t>({3, 4})) {
       auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
       if (!insertOp) {
-        return failure();
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.insert");
       }
       if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
-        return failure();
+        return rewriter.notifyMatchFailure(op,
+                                           "some scaled MFMAs already packed");
       }
 
       auto extractOp =
           insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
       if (!extractOp) {
-        return failure();
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.extract");
       }
 
       Value scaleSrc = extractOp.getOperand(0);
-      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      auto stype = dyn_cast<VectorType>(scaleSrc.getType());
       if (!stype) {
-        return failure();
+        return rewriter.notifyMatchFailure(op, "not a shaped type");
       }
       // We do not handle dynamic dims yet; assume that the input has been
       // padded to a static shape.
-      if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
-                       [&](int64_t i) { return stype.isDynamicDim(i); })) {
-        return failure();
+      if (!stype.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(op,
+                                           "dynamic dims not yet supported");
       }
 
       int64_t numElements = stype.getNumElements();
-      if (numElements <= 4) {
-        return failure();
+      if (numElements <= 4 || !(numElements % 4)) {
+        return rewriter.notifyMatchFailure(
+            op, "no packing if # of scales less than or indivisible by four");
       }
 
       Type newSrcType = VectorType::get(
@@ -760,7 +764,8 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
           loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
           SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
       Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
-      op.setOperand(opIdx, scale);
+      rewriter.modifyOpInPlace(
+          op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
       setOpsel(opIdx, offsets[1]);
     }
     return success();

>From 970aa1a9e55d00e5d9df108f4a0015a87234bd88 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 17 Sep 2025 17:47:41 -0500
Subject: [PATCH 4/4] Perform packing for inputs with shapes non-divisible by 4

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 56 +++++------
 mlir/test/Dialect/AMDGPU/canonicalize.mlir   | 97 +++++++++++++++++++-
 2 files changed, 118 insertions(+), 35 deletions(-)
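
A worked example of the new tail handling for scale sources whose element
count is not a multiple of four (numbers taken from the vector<7x23xf8E8M0FNU>
case in the updated test; the IR below is only a sketch of what the pattern
is expected to emit):

  // %scalesB[6, 22] linearizes to idx = 6 * 23 + 22 = 160 in a 161-element
  // vector. The 4-aligned window starting at offset 160 would run past the
  // end, so it is shifted back to the last full window:
  //   offset = 161 - 4 = 157, opsel = 160 - 157 = 3.
  %flat  = vector.shape_cast %scalesB : vector<7x23xf8E8M0FNU> to vector<161xf8E8M0FNU>
  %slice = vector.extract_strided_slice %flat {offsets = [157], sizes = [4], strides = [1]} : vector<161xf8E8M0FNU> to vector<4xf8E8M0FNU>
  // ...fed into amdgpu.scaled_mfma(%slice[3] * ...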

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 1cc800ec92090..e04a1d75724fb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include <cstdint>
@@ -688,36 +689,22 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       return idx;
     };
 
-    // Obtain offsets for new shape from flat index.
-    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
-      SmallVector<int64_t> res;
-      ShapedType shapedty = static_cast<ShapedType>(ty);
-      int64_t numElements = shapedty.getNumElements();
-      for (unsigned size : shapedty.getShape()) {
-        numElements /= size;
-        res.push_back(idx / numElements);
-        idx -= (idx / numElements) * size;
-      }
-      return res;
-    };
-
     // For every scale operand of this ScaledMFMAOp, if the scale matches the
     // following pattern:
-    //
-    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
-    // vector<?x?x?xf8E8M0FNU> %scale = vector.insert %unit, ... : f8E8M0FNU
-    // into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
+    // (f8 here means f8E8M0FNU)
+    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
+    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
+    // amdgpu.scaled_mfma(%scale[0] * ...
     //
     // rewrite to:
     //
-    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
-    // vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
-    // vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
     // amdgpu.scaled_mfma(%scale[0-3] * ...
     //
     // This creates duplicate shape_casts for every use but these will be
     // removed in CSE.
-    for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
       auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
       if (!insertOp) {
         return rewriter.notifyMatchFailure(op,
@@ -738,7 +725,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       Value scaleSrc = extractOp.getOperand(0);
       auto stype = dyn_cast<VectorType>(scaleSrc.getType());
       if (!stype) {
-        return rewriter.notifyMatchFailure(op, "not a shaped type");
+        return rewriter.notifyMatchFailure(op, "not a vector type");
       }
       // We do not handle dynamic dims yet; assume that the input has been
       // padded to a static shape.
@@ -748,25 +735,32 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       }
 
       int64_t numElements = stype.getNumElements();
-      if (numElements <= 4 || !(numElements % 4)) {
+      if (numElements <= 4) {
         return rewriter.notifyMatchFailure(
-            op, "no packing if # of scales less than or indivisible by four");
+            op, "no packing if # of scales is four or fewer");
+      }
+      int64_t idx = getIdxFromExtract(extractOp);
+      int64_t offset = idx - (idx % 4);
+      int64_t size = std::min<int64_t>(4, numElements - offset);
+      int64_t opsel = idx - offset;
+      if (size != 4) {
+        opsel += 4 - size;
+        offset = numElements - 4;
+        size = 4;
+      }
       }
 
-      Type newSrcType = VectorType::get(
-          SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+      Type newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
+                                        stype.getElementType());
       Value newScaleSrc =
           rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
-      int64_t idx = getIdxFromExtract(extractOp);
-      SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
       auto scaleTy = VectorType::get({4}, stype.getElementType());
       Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
-          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+          loc, newScaleSrc, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{size},
+          ArrayRef<int64_t>{1});
       Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
       rewriter.modifyOpInPlace(
           op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
-      setOpsel(opIdx, offsets[1]);
+      setOpsel(opIdx, opsel);
     }
     return success();
   }
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 75cbf29c95f29..8179d8e0ce513 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -163,11 +163,11 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
 // -----
 
 // CHECK-LABEL: func @scaled_mfma
-// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
-// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
-// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
-// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
 func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
@@ -184,3 +184,92 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
   %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_less_than_4
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
+func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0 : vector<4xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_ugly_shapes
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA_0_0 = vector.extract %scalesA[0, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_1 = vector.extract %scalesA[1, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_2 = vector.extract %scalesA[2, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_3 = vector.extract %scalesA[3, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+
+  // idx = 138 + 8 = 146 => opsel = 2
+  %scaleB_6_8 = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 147 => opsel = 3
+  %scaleB_6_9 = vector.extract %scalesB[6, 9] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 148 => opsel = 0
+  %scaleB_6_10 = vector.extract %scalesB[6, 10] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 149 => opsel = 1
+  %scaleB_6_11 = vector.extract %scalesB[6, 11] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 160 => opsel = 3 (last idx of last 4 bytes)
+  %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 159 => opsel = 3
+  %scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 158 => opsel = 2
+  %scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 157 => opsel = 1
+  %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+
+  %sA_0_0 = vector.insert %scaleA_0_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_1 = vector.insert %scaleA_0_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_2 = vector.insert %scaleA_0_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_3 = vector.insert %scaleA_0_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %sB_6_8 = vector.insert %scaleB_6_8, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_9 = vector.insert %scaleB_6_9, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_10 = vector.insert %scaleB_6_10, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_11 = vector.insert %scaleB_6_11, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  
+  %res_0 = amdgpu.scaled_mfma(%sA_0_0[0] * %opA) * (%sB_6_8[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_1 = amdgpu.scaled_mfma(%sA_0_1[0] * %opA) * (%sB_6_9[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_2 = amdgpu.scaled_mfma(%sA_0_2[0] * %opA) * (%sB_6_10[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_3 = amdgpu.scaled_mfma(%sA_0_3[0] * %opA) * (%sB_6_11[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1, %res_2, %res_3, %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
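
(For anyone trying this locally: the new cases should fire under plain
canonicalization, e.g. something along the lines of
"mlir-opt --canonicalize --cse mlir/test/Dialect/AMDGPU/canonicalize.mlir",
with CSE cleaning up the duplicated shape_casts; the test file's RUN line is
not shown in this hunk, so treat that command as an assumption.)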



More information about the Mlir-commits mailing list