[Mlir-commits] [mlir] 105ce58 - [mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper (#137498)
llvmlistbot at llvm.org
Fri May 2 16:07:21 PDT 2025
Author: Muzammil
Date: 2025-05-02T18:07:17-05:00
New Revision: 105ce585d35eff433031d3edce977eba97eeb6ff
URL: https://github.com/llvm/llvm-project/commit/105ce585d35eff433031d3edce977eba97eeb6ff
DIFF: https://github.com/llvm/llvm-project/commit/105ce585d35eff433031d3edce977eba97eeb6ff.diff
LOG: [mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper (#137498)
Create a wrapper around the new scaled MFMAs that operate on specific
element types and tile sizes.
See [Issue](https://github.com/iree-org/iree/issues/20616).
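As a usage sketch (mirroring the ops.mlir test added in this patch; the %a, %b, %acc, and %scale names are placeholders), the scales and their opsel indices are written inline with the A and B operands:

  %0 = amdgpu.scaled_mfma(%scale[0] * %a) * (%scale[1] * %b) + %acc
    { k = 64 : i32, m = 32 : i32, n = 32 : i32 }
    : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>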
---------
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
Added:
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
mlir/test/Dialect/AMDGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f14aa5a2e1564..0cebecee0390f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -687,6 +687,10 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
+// scaled_mfma
+def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
+ VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
+def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
[4, 8, 16],
@@ -804,7 +808,7 @@ def AMDGPU_GatherToLDSOp :
TypeAttr:$transferType
)>,
Results<(outs)> {
- let summary = "MLIR wrapper for CDNA mfma instructions";
+ let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
let description = [{
The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
@@ -830,4 +834,52 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}
+def AMDGPU_ScaledMFMAOp :
+ AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ I32Attr:$m,
+ I32Attr:$n,
+ I32Attr:$k,
+ ScaledMFMAInTypes:$sourceA,
+ ScaledMFMAInTypes:$sourceB,
+ ScaledMFMAOutTypes:$destC,
+ AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesA,
+ AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesB,
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxA,
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxB
+ )>,
+ Results<(outs ScaledMFMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for CDNA scaled mfma instructions";
+ let description = [{
+ The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
+ for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
+ multiple outer products in order to allow fast matrix multiplication.
+
+ The wrapper will select an appropriate `mfma` instruction, if one is available,
+    based on the provided `m`, `n`, and `k` attributes, along with the
+ types of the source and destination arguments.
+
+    Note that this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+ intrinsics that take an integer type of width `4K`. For example,
+ one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
+ logically takes 4 i8s but whose intrinsics are specified to take an i32.
+ In these cases, the bytes in the vector will be concatenated in little-endian
+ order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
+
+    This wrapper takes inspiration from `amdgpu.mfma`, but has some key
+    differences:
+ - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
+ fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
+ size.
+ - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
+ are omitted from this wrapper.
+ - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
+ double-precision operations on gfx94x and so are not included here.
+ }];
+ let assemblyFormat = [{
+ `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
+ attr-dict
+ `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
+ }];
+}
#endif // AMDGPU
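For reference, a sketch of the two tile shapes the op accepts, as exercised by the mfma-gfx950.mlir test below (%a, %b, %acc0, %acc1, and %scale are placeholder names; scales may be a scalar f8E8M0FNU or a vector<4xf8E8M0FNU> indexed by the bracketed opsel):

  // m = n = 32, k = 64: 32-element fp8 inputs, vector<16xf32> accumulator.
  %d0 = amdgpu.scaled_mfma(%scale[0] * %a) * (%scale[1] * %b) + %acc0
    { m = 32 : i32, n = 32 : i32, k = 64 : i32 }
    : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, vector<16xf32>

  // m = n = 16, k = 128: same inputs, vector<4xf32> accumulator.
  %d1 = amdgpu.scaled_mfma(%scale[0] * %a) * (%scale[1] * %b) + %acc1
    { m = 16 : i32, n = 16 : i32, k = 128 : i32 }
    : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, vector<4xf32>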
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91dbc2de65c4e..6e596485cbb58 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -23,6 +23,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include <optional>
namespace mlir {
@@ -528,6 +529,25 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
return input;
}
+/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
+/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+///
+/// Specifically:
+/// 1. If `input` is an i8 value, zero-extend it to i32
+/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
+///
+/// Note that the type of `input` has already been LLVM type converted:
+/// therefore 8-bit and smaller floats are represented as their corresponding
+/// `iN` integers.
+static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input) {
+ Type inputType = input.getType();
+ Type outputType = rewriter.getI32Type();
+ if (auto intType = dyn_cast<IntegerType>(inputType))
+ return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
+ return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+}
+
/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
@@ -833,6 +853,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
mfma.getBlocks(), chipset);
}
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
+ return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+ smfma.getSourceB().getType(),
+ smfma.getDestC().getType(), smfma.getM(),
+ smfma.getN(), smfma.getK(), 1u, chipset);
+}
+
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -954,6 +982,52 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}
};
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+ ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
+
+ if (chipset.majorVersion != 9 || chipset < kGfx950)
+      return op->emitOpError("scaled MFMA only supported on gfx950+");
+ std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+ maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+ if (!maybeScaledIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching scaled MFMA size on given chipset");
+
+ auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+ OperationState loweredOp(loc, intrinsicName);
+ loweredOp.addTypes(intrinsicOutType);
+ loweredOp.addOperands(
+ {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+ adaptor.getDestC()});
+ Value scalesIdxA =
+ createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
+ Value scalesIdxB =
+ createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
+ loweredOp.addOperands(
+ {createI32Constant(rewriter, loc, aTypeCode),
+ createI32Constant(rewriter, loc, bTypeCode),
+ /*scales idx A=*/scalesIdxA,
+ /*scales A*/
+ castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
+ /*scales idx B=*/scalesIdxB,
+ /*scales B*/
+ castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
+ Value lowered = rewriter.create(loweredOp)->getResult(0);
+ rewriter.replaceOp(op, lowered);
+ return success();
+ }
+};
+
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1548,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
- MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering>(converter, chipset);
+ MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
+ ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
+ chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index de63f249bb530..52a5d39f668c6 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -51,3 +51,54 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
func.return
}
+
+// CHECK-LABEL: func @scaled_mfma_to_rocdl(
+// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
+func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
+ %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
+ %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
+ %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
+ %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
+
+ // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
+ // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.mlir.constant(3 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.mlir.constant(4 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
+
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 16b3193d270cb..188cfcc4eb38b 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -164,3 +164,10 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
func.return %0 : f32
}
+
+// CHECK-LABEL: func @scaled_mfma
+func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ // CHECK: amdgpu.scaled_mfma
+ %0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+ func.return %0 : vector<16xf32>
+}