[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Sep 18 12:43:06 PDT 2025


================
@@ -631,6 +635,139 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check if the scales input is used in other scaled mfma's while they exist.
+/// If they're unused then pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto setOpsel = [&op](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale is produced by
+    // the extraction of a single scale from some vector, then attempt to
+    // extract 4 values from that vector instead.
+    //
+    // Example: (f8 here means f8E8M0FNU)
+    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
+    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates duplicate shape_casts for every use but these will be
+    // removed in CSE.
+    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.insert");
+      }
+      // If the extracted value is not a single scalar, then it has been packed.
+      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
+        return rewriter.notifyMatchFailure(
+            op, "scaled mfma operand already packed");
+      }
+
+      auto extractOp =
+          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.extract");
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
+      if (!scaleSrcType) {
+        return rewriter.notifyMatchFailure(op, "not a vector type");
+      }
+
+      // We do not handle dynamic dims yet, assume that the input is padded to
+      // a static shape now.
+      if (!scaleSrcType.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(op,
+                                           "dynamic dims not yet supported");
+      }
+
+      int64_t numElements = scaleSrcType.getNumElements();
+      if (numElements <= 4) {
+        return rewriter.notifyMatchFailure(
+            op, "no packing if # of scales less than four");
+      }
+
+      // Find a linearized idx using the size and offsets of the extract op.
+      auto extractedPos = llvm::to_vector_of<int64_t>(
+          llvm::reverse(extractOp.getStaticPosition()));
+      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
+      int64_t scaleSrcRank = scaleSrcType.getRank();
+      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
+      for (int64_t i = 1; i < scaleSrcRank; ++i) {
+        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
+      }
+      int64_t idx = linearize(extractedPos, extractSizes);
+
+      // All n scales (where n is the total number of scales) must now be
+      // extracted in chunks of 4 elements. This is done by dividing the
+      // original vector of scales into groups of 4 elements at offsets
+      // 0, 4, ..., 4 * (m - 1) (where m = n / 4, rounded down). All
+      // extractions of a scale at a particular index are now replaced with
+      // an extraction of the entire group of 4 elements to which that index
+      // belongs.
+      //
+      // If the number of scales happens to be indivisible by 4, extract
+      // the remaining n - 4 * m scales in a chunk of 4 elements starting at
+      // offset n - 4.
+      int64_t offset = idx - (idx % 4);
+      int64_t opsel = idx - offset;
+      int64_t size = 4l;
+      // Accommodate remaining elements in the case of non-4-divisible vectors.
+      if (numElements - offset < size) {
+        opsel = size - (numElements - idx);
+        offset = numElements - 4l;
+      }
+      Type scaleSrcElemType = scaleSrcType.getElementType();
+      auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
----------------
krzysz00 wrote:

Yeah, probably a good trivial followup once commit access goes through, good catch

https://github.com/llvm/llvm-project/pull/155951


More information about the Mlir-commits mailing list