[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Muzammiluddin Syed
llvmlistbot at llvm.org
Thu Nov 27 13:37:25 PST 2025
================
@@ -1363,6 +1363,104 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+ ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto outType =
+ typeConverter->convertType<VectorType>(op.getDestD().getType());
+ if (!outType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ if (chipset < Chipset(12, 5, 0))
+ return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+ int64_t m = op.getM();
+ int64_t n = op.getN();
+ int64_t k = op.getK();
+
+ Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+ Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+ std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+ std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+ if (!aFmtCode || !bFmtCode)
+ return op.emitOpError("unsupported element types for scaled_wmma");
+
+ // Determine which intrinsic to use based on dimensions and scale type
+ StringRef intrinsicName;
+ bool isScale16 = adaptor.getScaleA().getType().isInteger(64);
+ bool is32x16 = (m == 32 && n == 16 && k == 128);
+
+ if (m == 16 && n == 16 && k == 128) {
+ intrinsicName =
+ isScale16
+ ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+ : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+ } else if (is32x16) {
+ intrinsicName =
+ isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+ : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+ } else {
+ return op.emitOpError("unsupported scaled_wmma dimensions: ")
+ << m << "x" << n << "x" << k;
+ }
+
+ SmallVector<NamedAttribute, 8> attrs;
+
+ // The f4 variant does not have fmtA and fmtB attributes
+ if (!is32x16) {
+ attrs.push_back(
+ rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
+ attrs.push_back(
+ rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
+ }
+
+ // Add modifier attributes - modC and reuse flags default to 0/false
----------------
Muzammiluddin-Syed-ECE wrote:
Not related to this PR, but ideally we would be able to use a TableGen'd builder function for this instead of having to create an op manually like this. But that can be a future PR.
https://github.com/llvm/llvm-project/pull/169854
More information about the Mlir-commits
mailing list