[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Sep 17 12:50:14 PDT 2025


================
@@ -631,6 +633,148 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pack the scales operand of a ScaledMFMAOp, provided that every use of the
+/// scales is a scaled MFMA that has not already been packed (opsel == 0).
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a nonzero opsel, packing has already been
+    // done. Conservatively treat any user that is not a scaled MFMA as
+    // unpackable.
+    auto checkIfUnpackable = [](OpOperand &use) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(use.getOwner())) {
+        switch (use.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+        default:
+          return true;
+        }
+      }
+      return true;
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain the flat index from a vector.extract's offsets and source shape.
+    auto getIdxFromExtract = [](vector::ExtractOp extractOp) {
+      ShapedType ty = cast<ShapedType>(extractOp.getOperand(0).getType());
+      int64_t cumul = 1;
+      int64_t idx = 0;
+      for (auto [offset, size] : llvm::reverse(
+               llvm::zip_equal(extractOp.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Obtain the offsets into the new shape from a flat index.
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      auto shapedTy = cast<ShapedType>(ty);
+      int64_t numElements = shapedTy.getNumElements();
+      for (int64_t size : shapedTy.getShape()) {
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx %= numElements;
+      }
+      return res;
+    };
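+    // For example, flat index 23 into vector<6x4xf8E8M0FNU> gives offsets
+    // [5, 3] (23 / 4 = 5, then 23 % 4 = 3).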
+
+    // For each scale operand of this ScaledMFMAOp, if the scale is produced
+    // by the pattern:
+    //
+    //   %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU
+    //       from vector<?x?x?xf8E8M0FNU>
+    //   %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    //   amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite it to:
+    //
+    //   %reshaped = vector.shape_cast %ScaleSrc
+    //       : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+    //   %scale = vector.extract %reshaped[?]
+    //       : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    //   amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates a duplicate shape_cast for every use, but those will be
+    // removed by CSE.
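+    // Operands 3 and 4 are the A and B scales, respectively.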
+    for (int64_t opIdx : {3, 4}) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return failure();
+      }
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+        return failure();
+      }
+
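+      // The inserted scalar must come directly from a vector.extract of the
+      // scale source.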
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return failure();
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      if (!stype) {
+        return failure();
+      }
+      // We do not handle dynamic dims yet; for now, assume the input has been
+      // padded to a static shape.
+      if (!stype.hasStaticShape()) {
+        return failure();
+      }
+
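+      // A scale source with at most 4 elements already fits in a single
+      // vector<4xf8E8M0FNU>, so there is nothing to pack.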
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4) {
+        return failure();
+      }
+
+      Type newSrcType = VectorType::get(
----------------
krzysz00 wrote:

This won't work if the number of elements isn't evenly divisible by 4?

https://github.com/llvm/llvm-project/pull/155951


More information about the Mlir-commits mailing list