[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Sep 18 08:28:42 PDT 2025
================
@@ -631,6 +635,139 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check if the scales input is used in other scaled MFMAs while they exist.
+/// If they're unused, then pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ScaledMFMAOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto setOpsel = [&](unsigned idx, int64_t val) {
+ switch (idx) {
+ case 3:
+ op.setScalesIdxA(val);
+ break;
+ case 4:
+ op.setScalesIdxB(val);
+ break;
+ default:
+ break;
+ }
+ };
+
+ // For every scale operand of this ScaledMFMAOp, if the scale is produced by
+ // the extraction of a single scale from some vector, then attempt to
+ // extract 4 values from that vector instead.
+ //
+ // Example: (f8 here means f8E8M0FNU)
+ // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
+ // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
+ // amdgpu.scaled_mfma(%scale[0] * ...
+ //
+ // rewrite to:
+ //
+ // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
+ // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
+ // amdgpu.scaled_mfma(%scale[0-3] * ...
+ //
+ // This creates duplicate shape_casts for every use but these will be
+ // removed in CSE.
+ for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
+ auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+ if (!insertOp) {
+ return rewriter.notifyMatchFailure(op,
+ "defining op not a vector.insert");
+ }
+ // If the extracted value is not a single scalar, then it has been packed.
+ if (dyn_cast<VectorType>(insertOp.getValueToStore().getType())) {
+ return rewriter.notifyMatchFailure(
+ op, "scaled mfma operand already packed");
+ }
+
+ auto extractOp =
+ insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
+ if (!extractOp) {
+ return rewriter.notifyMatchFailure(op,
+ "defining op not a vector.extract");
+ }
+
+ Value scaleSrc = extractOp.getOperand(0);
+ auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
+ if (!scaleSrcType) {
+ return rewriter.notifyMatchFailure(op, "not a vector type");
+ }
+
+ // We do not handle dynamic dims yet, assume that the input is padded to
+ // a static shape now.
+ if (!scaleSrcType.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "dynamic dims not yet supported");
+ }
+
+ int64_t numElements = scaleSrcType.getNumElements();
+ if (numElements <= 4) {
+ return rewriter.notifyMatchFailure(
+ op, "no packing if # of scales less than four");
+ }
+
+ // Find a linearized idx using the size and offsets of the extract op.
+ ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
+ int64_t scaleSrcRank = scaleSrcType.getRank();
+ SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
+ SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
+ std::reverse(extractedPos.begin(), extractedPos.end());
+ for (int64_t i = 1; i < scaleSrcRank; i++) {
----------------
kuhar wrote:
```suggestion
for (int64_t i = 1; i < scaleSrcRank; ++i) {
```
https://llvm.org/docs/CodingStandards.html#prefer-preincrement
https://github.com/llvm/llvm-project/pull/155951
More information about the Mlir-commits
mailing list