[Mlir-commits] [mlir] [mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper (PR #137498)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Apr 29 23:52:29 PDT 2025
================
@@ -954,6 +964,50 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}
};
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+ ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
+
+ if (chipset.majorVersion != 9 || chipset < kGfx950)
+ return op->emitOpError("scaled MFMA only supported on gfx908+");
+ std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+ maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+ if (!maybeScaledIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching scaled MFMA size on given chipset");
+
+ auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+ OperationState loweredOp(loc, intrinsicName);
+ loweredOp.addTypes(intrinsicOutType);
+ loweredOp.addOperands(
+ {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+ adaptor.getDestC()});
+ Value scalesIdxA = createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
+ Value scalesIdxB = createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
+ loweredOp.addOperands(
+ {createI32Constant(rewriter, loc, aTypeCode),
+ createI32Constant(rewriter, loc, bTypeCode),
+ /*scales A*/
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
----------------
krzysz00 wrote:
Huh, yeah, that will just so happen to do the right thing for scalar inputs, I think.
https://github.com/llvm/llvm-project/pull/137498
More information about the Mlir-commits mailing list