[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Dec 1 06:26:55 PST 2025
================
@@ -1404,93 +1427,74 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
if (!aFmtCode || !bFmtCode)
return op.emitOpError("unsupported element types for scaled_wmma");
- // Get scale vector types and determine variant (scale vs scale16)
+ // Get scale vector types and determine variant (scale vs scale16).
auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
- bool isScale16 = (scaleAVecType.getNumElements() == 8);
- if (isScale16 != (scaleBVecType.getNumElements() == 8))
+ if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
return op.emitOpError("scaleA and scaleB must have equal vector length");
- // Extract scale format from element types
+ // Extract scale format from element types.
Type scaleAElemType = scaleAVecType.getElementType();
Type scaleBElemType = scaleBVecType.getElementType();
- // Map f8 types to format codes
- auto getScaleFormat = [](Type elemType) -> std::optional<uint32_t> {
- if (isa<Float8E8M0FNUType>(elemType))
- return 0;
- if (isa<Float8E4M3FNType>(elemType))
- return 2;
- return std::nullopt;
- };
-
- std::optional<uint32_t> scaleAFmt = getScaleFormat(scaleAElemType);
- std::optional<uint32_t> scaleBFmt = getScaleFormat(scaleBElemType);
+ std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
+ std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
if (!scaleAFmt || !scaleBFmt)
return op.emitOpError("unsupported scale element types");
- // Determine which intrinsic to use based on dimensions
- StringRef intrinsicName;
- bool is32x16 = (m == 32 && n == 16 && k == 128);
-
- if (m == 16 && n == 16 && k == 128) {
- intrinsicName =
- isScale16
- ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
- : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
- } else if (is32x16) {
- intrinsicName =
- isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
- : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
- } else {
+ // Determine which intrinsic to use based on dimensions.
+ bool isScale16 = (scaleAVecType.getNumElements() == 8);
+ std::optional<StringRef> intrinsicName =
+ getScaledWmmaIntrinsicName(m, n, k, isScale16);
+ if (!intrinsicName)
return op.emitOpError("unsupported scaled_wmma dimensions: ")
<< m << "x" << n << "x" << k;
- }
SmallVector<NamedAttribute, 8> attrs;
- // The f4 variant does not have fmtA and fmtB attributes
+ // The f4 variant does not have fmtA and fmtB attributes.
+ bool is32x16 = (m == 32 && n == 16 && k == 128);
if (!is32x16) {
- attrs.push_back(
+ attrs.emplace_back(
rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
- attrs.push_back(
+ attrs.emplace_back(
rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
----------------
kuhar wrote:
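You don't need to call `rewriter.getNamedAttr` in any of these `emplace_back` calls: pass the attribute name and value directly and let `emplace_back` construct the `NamedAttribute` in place.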
https://github.com/llvm/llvm-project/pull/169854
More information about the Mlir-commits mailing list