[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Sep 17 12:50:13 PDT 2025


================
@@ -631,6 +633,148 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check whether the scales input is used by other scaled MFMA ops while they
+/// exist. If they're unused, pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non zero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &op) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
+        switch (op.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+          break;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+          break;
+        default:
+          return true;
+          break;
+        }
+      }
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        return op.setScalesIdxA(val);
+        break;
+      case 4:
+        return op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain flat index from offsets and shape.
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
+      int cumul = 1;
----------------
krzysz00 wrote:

int64_t for things holding shapes

https://github.com/llvm/llvm-project/pull/155951


More information about the Mlir-commits mailing list