[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 27 11:40:52 PST 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp --diff_from_common_commit
``````````
:warning:
If you are using a stacked PR workflow, the reproduction instructions above may
report results for more than one PR in the stack. You can limit the results by
changing `origin/main` to the base branch or commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f4034f44d..731e33c82 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1212,12 +1212,11 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
- loweredOp.addOperands(
- {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
- allowBf16),
- packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
- allowBf16),
- adaptor.getDestC()});
+ loweredOp.addOperands({packSmallFloatVectorOperand(
+ rewriter, loc, adaptor.getSourceA(), allowBf16),
+ packSmallFloatVectorOperand(
+ rewriter, loc, adaptor.getSourceB(), allowBf16),
+ adaptor.getDestC()});
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1401,13 +1400,14 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
bool is32x16 = (m == 32 && n == 16 && k == 128);
if (m == 16 && n == 16 && k == 128) {
- intrinsicName = isScale16
- ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
- : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+ intrinsicName =
+ isScale16
+ ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+ : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
} else if (is32x16) {
- intrinsicName = isScale16
- ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
- : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+ intrinsicName =
+ isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+ : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
} else {
return op.emitOpError("unsupported scaled_wmma dimensions: ")
<< m << "x" << n << "x" << k;
@@ -1417,29 +1417,29 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
// The f4 variant does not have fmtA and fmtB attributes
if (!is32x16) {
- attrs.push_back(rewriter.getNamedAttr("fmtA",
- rewriter.getI32IntegerAttr(*aFmtCode)));
- attrs.push_back(rewriter.getNamedAttr("fmtB",
- rewriter.getI32IntegerAttr(*bFmtCode)));
+ attrs.push_back(
+ rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
+ attrs.push_back(
+ rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
}
// Add modifier attributes - modC and reuse flags default to 0/false
- attrs.push_back(rewriter.getNamedAttr("reuseA",
- rewriter.getBoolAttr(false)));
- attrs.push_back(rewriter.getNamedAttr("reuseB",
- rewriter.getBoolAttr(false)));
- attrs.push_back(rewriter.getNamedAttr("modC",
- rewriter.getI16IntegerAttr(0)));
+ attrs.push_back(
+ rewriter.getNamedAttr("reuseA", rewriter.getBoolAttr(false)));
+ attrs.push_back(
+ rewriter.getNamedAttr("reuseB", rewriter.getBoolAttr(false)));
+ attrs.push_back(
+ rewriter.getNamedAttr("modC", rewriter.getI16IntegerAttr(0)));
// Scale type/format parameters from the operation
- attrs.push_back(rewriter.getNamedAttr("scaleAType",
- rewriter.getI32IntegerAttr(op.getScaleAType())));
- attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
- rewriter.getI32IntegerAttr(op.getFmtScaleA())));
- attrs.push_back(rewriter.getNamedAttr("scaleBType",
- rewriter.getI32IntegerAttr(op.getScaleBType())));
- attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
- rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "scaleAType", rewriter.getI32IntegerAttr(op.getScaleAType())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "fmtScaleA", rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "scaleBType", rewriter.getI32IntegerAttr(op.getScaleBType())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "fmtScaleB", rewriter.getI32IntegerAttr(op.getFmtScaleB())));
// Convert typed float vectors to packed i32 format if needed
Value sourceA =
@@ -2428,10 +2428,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
- ScaledExtPacked816OpLowering,
- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering, TransposeLoadOpLowering,
- AMDGPUPermlaneLowering>(converter, chipset);
+ ScaledExtPacked816OpLowering, ScaledExtPackedOpLowering,
+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 87bd19032..fceded5c2 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -450,7 +450,7 @@ LogicalResult ScaledWMMAOp::verify() {
auto sourceAType = cast<VectorType>(getSourceA().getType());
auto sourceBType = cast<VectorType>(getSourceB().getType());
auto destType = cast<VectorType>(getDestC().getType());
-
+
// Validate output type is F32
if (!destType.getElementType().isF32())
return emitOpError("destination must have f32 element type");
@@ -459,10 +459,10 @@ LogicalResult ScaledWMMAOp::verify() {
Type aElemType = sourceAType.getElementType();
Type bElemType = sourceBType.getElementType();
- bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
- aElemType.isFloat(8);
- bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
- bElemType.isFloat(8);
+ bool aIsSmallFloat =
+ aElemType.isFloat(4) || aElemType.isFloat(6) || aElemType.isFloat(8);
+ bool bIsSmallFloat =
+ bElemType.isFloat(4) || bElemType.isFloat(6) || bElemType.isFloat(8);
if (!aIsSmallFloat || !bIsSmallFloat)
return emitOpError("source operands must have small float element types "
@@ -479,7 +479,7 @@ LogicalResult ScaledWMMAOp::verify() {
int64_t aLen = sourceAType.getNumElements();
int64_t bLen = sourceBType.getNumElements();
int64_t expectedOutLen = (m == 16) ? 4 : 8;
-
+
if (destType.getNumElements() != expectedOutLen)
return emitOpError("expected output vector of length " +
Twine(expectedOutLen) + " but got " +
``````````
</details>
https://github.com/llvm/llvm-project/pull/169854
More information about the Mlir-commits mailing list