[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Nov 30 06:26:26 PST 2025
================
@@ -1363,6 +1373,136 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+ ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto outType =
+ typeConverter->convertType<VectorType>(op.getDestD().getType());
+ if (!outType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ if (chipset < Chipset(12, 5, 0))
+ return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+ int64_t m = op.getM();
+ int64_t n = op.getN();
+ int64_t k = op.getK();
+
+ Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+ Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+ std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+ std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+ if (!aFmtCode || !bFmtCode)
+ return op.emitOpError("unsupported element types for scaled_wmma");
+
+ // Get scale vector types and determine variant (scale vs scale16)
+ auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
+ auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
+
+ bool isScale16 = (scaleAVecType.getNumElements() == 8);
+ if (isScale16 != (scaleBVecType.getNumElements() == 8))
+ return op.emitOpError("scaleA and scaleB must have equal vector length");
+
+ // Extract scale format from element types
+ Type scaleAElemType = scaleAVecType.getElementType();
+ Type scaleBElemType = scaleBVecType.getElementType();
+
+ // Map f8 types to format codes
+ auto getScaleFormat = [](Type elemType) -> std::optional<uint32_t> {
+ if (isa<Float8E8M0FNUType>(elemType))
+ return 0;
+ if (isa<Float8E4M3FNType>(elemType))
+ return 2;
+ return std::nullopt;
+ };
----------------
kuhar wrote:
Use `TypeSwitch`. I'd also move this scale-format mapping into a helper function -- `matchAndRewrite` is already quite long.
https://github.com/llvm/llvm-project/pull/169854
More information about the Mlir-commits
mailing list