[Mlir-commits] [mlir] [MLIR][NVVM] Explicit Data Type for Output in `wgmma.mma_async` (PR #78713)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 19 05:43:39 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-llvm
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
The current implementation of the `nvvm.wgmma.mma_async` Op deduces the data type of the output matrix from the element type of the result struct, which can be unintuitive — especially in cases where types like `2xf16` are packed into an `i32`.
This PR addresses this issue by improving the Op to include an explicit data type for the output matrix.
The modified Op now includes an explicit data type attribute for Matrix-D (e.g. `<f16>`), and looks as follows:
```
%result = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, ...
nvvm.wgmma.mma_async
%descA, %descB, %result,
#nvvm.shape<m = 64, n = 32, k = 16>,
D [<f16>, #nvvm.wgmma_scale_out<zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
```
---
Patch is 42.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78713.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+7-3)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+11-4)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+58-61)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+9-9)
- (modified) mlir/test/Conversion/NVVMToLLVM/invalid.mlir (+21-21)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+28-28)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..b1bd3a95068076 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1833,11 +1833,14 @@ def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
+def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>;
+def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>;
+
def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
[WGMMATypeF16, WGMMATypeTF32,
WGMMATypeU8, WGMMATypeS8,
WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3,
- WGMMATypeF8E5M2]> {
+ WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -1859,6 +1862,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
NVVM_MMAShapeAttr:$shape,
WGMMATypesAttr:$typeA,
WGMMATypesAttr:$typeB,
+ WGMMATypesAttr:$typeD,
WGMMAScaleOutAttr:$scaleD,
WGMMAScaleInAttr:$scaleA,
WGMMAScaleInAttr:$scaleB,
@@ -1868,8 +1872,8 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
);
let assemblyFormat = [{
- $descriptorA `,` $descriptorB `,` $shape `,`
- `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
+ $descriptorA `,` $descriptorB `,` $inouts `,` $shape `,`
+ `D` `[` $typeD `,` $scaleD (`,` $satfinite^)? `]` `,`
`A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,`
`B` `[` $typeB `,` $scaleB `,` $layoutB `]`
attr-dict `:`
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5..9950499817789d 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1267,10 +1267,11 @@ struct NVGPUWarpgroupMmaOpLowering
}
/// Generates WGMMATypesAttr from MLIR Type
- NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
- auto getWgmmaType = [](Type elemType) {
+ NVVM::WGMMATypesAttr generateWgmmaType(Type type,
+ bool useF32 = false) const {
+ auto getWgmmaType = [=](Type elemType) {
if (elemType.isF32() || elemType.isTF32())
- return NVVM::WGMMATypes::tf32;
+ return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
if (elemType.isF16())
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
@@ -1285,6 +1286,8 @@ struct NVGPUWarpgroupMmaOpLowering
return NVVM::WGMMATypes::s8;
if (elemType.isUnsignedInteger(8))
return NVVM::WGMMATypes::u8;
+ if (elemType.isInteger(32))
+ return NVVM::WGMMATypes::s32;
llvm_unreachable("unsupported type");
};
return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
@@ -1397,6 +1400,9 @@ struct NVGPUWarpgroupMmaOpLowering
Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+ Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
+ NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
+
NVVM::MMAShapeAttr shape = generateWgmmaShape();
NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
@@ -1408,7 +1414,8 @@ struct NVGPUWarpgroupMmaOpLowering
return b.create<NVVM::WgmmaMmaAsyncOp>(
matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
- itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+ itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+ overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc0..bb720603819407 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -755,37 +755,44 @@ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
return failure();
}
-LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA,
+LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
+ NVVM::WGMMATypes typeA,
NVVM::WGMMATypes typeB) {
switch (typeA) {
case NVVM::WGMMATypes::f16:
- if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16)
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+ typeB == NVVM::WGMMATypes::f16)
return success();
break;
case NVVM::WGMMATypes::tf32:
- if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32)
+ if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
return success();
break;
case NVVM::WGMMATypes::u8:
case NVVM::WGMMATypes::s8:
- if (typeD.isInteger(32) &&
+ if (typeD == NVVM::WGMMATypes::s32 &&
(typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
return success();
break;
case NVVM::WGMMATypes::b1:
- if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1)
+ if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
return success();
break;
case NVVM::WGMMATypes::bf16:
- if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16)
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+ typeB == NVVM::WGMMATypes::bf16)
return success();
break;
case NVVM::WGMMATypes::e4m3:
case NVVM::WGMMATypes::e5m2:
- if ((typeD.isF32() || typeD.isF16()) &&
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
(typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
return success();
break;
+ case WGMMATypes::f32:
+ case WGMMATypes::s32:
+ llvm_unreachable("unsupported input types");
+ break;
}
return failure();
}
@@ -799,19 +806,24 @@ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
80, 96, 112, 128, 144, 160,
176, 192, 208, 224, 240, 256};
switch (typeA) {
- case mlir::NVVM::WGMMATypes::f16:
- case mlir::NVVM::WGMMATypes::tf32:
- case mlir::NVVM::WGMMATypes::bf16:
- case mlir::NVVM::WGMMATypes::e4m3:
- case mlir::NVVM::WGMMATypes::e5m2:
+ case WGMMATypes::f16:
+ case WGMMATypes::tf32:
+ case WGMMATypes::bf16:
+ case WGMMATypes::e4m3:
+ case WGMMATypes::e5m2:
if (llvm::is_contained(allowedN, sizeN))
return success();
break;
- case mlir::NVVM::WGMMATypes::u8:
- case mlir::NVVM::WGMMATypes::s8:
- case mlir::NVVM::WGMMATypes::b1:
+ case WGMMATypes::u8:
+ case WGMMATypes::s8:
+ case WGMMATypes::b1:
if (llvm::is_contained(allowedNshort, sizeN))
return success();
+ break;
+ case WGMMATypes::f32:
+ case WGMMATypes::s32:
+ llvm_unreachable("unsupported input types");
+ break;
}
return failure();
}
@@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
if (!stype)
return emitOpError() << "expected results to be struct";
- Type outputType = stype.getBody().front();
int outputSize = stype.getBody().size();
+ WGMMATypes typeD = getTypeD();
+ WGMMATypes typeA = getTypeA();
+ WGMMATypes typeB = getTypeB();
+
for (Type t : stype.getBody()) {
- if (t != outputType)
+ if (t != stype.getBody().front())
return emitOpError()
<< "all elements in struct must be same type but there is " << t;
}
- if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
+ if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
+ typeD != WGMMATypes::s32) {
return emitOpError() << "does not support the given output type "
- << outputType;
+ << NVVM::stringifyWGMMATypes(typeD);
}
- if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
- getScaleB() == NVVM::WGMMAScaleIn::neg)) {
- return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+ if (typeD == WGMMATypes::s32 &&
+ (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
+ return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg ";
}
- mlir::NVVM::WGMMATypes typeA = getTypeA();
- mlir::NVVM::WGMMATypes typeB = getTypeB();
- if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) {
- return emitOpError() << outputType
+ if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
+ return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
<< " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
<< NVVM::stringifyWGMMATypes(typeB)
<< ", it is not supported.";
@@ -866,8 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
}
// Check transpose (only available for f16/bf16)
- if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
- typeA != mlir::NVVM::WGMMATypes::bf16) &&
+ if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
(getLayoutA() == mlir::NVVM::MMALayout::col ||
getLayoutB() == mlir::NVVM::MMALayout::col)) {
return emitOpError()
@@ -876,29 +889,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
<< " for input types " << stringifyWGMMATypes(typeA) << " and "
<< stringifyWGMMATypes(typeB)
<< " requires transpose. However, this is only supported for: "
- << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
- << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
+ << stringifyMMATypes(MMATypes::f16) << " and "
+ << stringifyMMATypes(MMATypes::bf16);
}
// Check result registers
- int expectedOutput;
- if (outputType.isF32() || outputType.isInteger(32))
+ int expectedOutput = 0;
+ if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
expectedOutput = getShape().getN() / 2;
- if (outputType.isF16())
+ if (typeD == WGMMATypes::f16)
expectedOutput = getShape().getN() / 4;
if (outputSize != expectedOutput) {
return emitOpError() << "results " << expectedOutput
<< ", however output struct has " << outputSize
<< " elements";
}
- // Check satfinite (only availalbe for s32 accumulator)
- if (!outputType.isInteger(32) &&
+ // Check satfinite (only available for s32 accumulator)
+ if (typeD != WGMMATypes::s32 &&
getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
NVVM::MMAIntOverflow::satfinite) {
return emitOpError()
<< " `satfinite` can be only used with s32 accumulator, however "
"the current accumulator is "
- << outputType;
+ << NVVM::stringifyWGMMATypes(typeD);
}
return success();
@@ -907,27 +920,15 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
- bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
- getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+ bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
- Value outValue = getResults() ? getResults() : getInouts();
- auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
- Type outputType = stype.getBody().front();
- std::string outputTypeName;
- if (outputType.isF16())
- outputTypeName = "f16";
- else if (outputType.isF32())
- outputTypeName = "f32";
- else if (outputType.isInteger(32))
- outputTypeName = "s32";
- else
- assert(false && "unsupported output type");
+ StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
- int expectedOutputRegisters;
- if (outputType.isF32() || outputType.isInteger(32))
- expectedOutputRegisters = getShape().getN() / 2;
- if (outputType.isF16())
+ int expectedOutputRegisters = 0;
+ if (getTypeD() == WGMMATypes::f16)
expectedOutputRegisters = getShape().getN() / 4;
+ else
+ expectedOutputRegisters = getShape().getN() / 2;
std::string ptx;
llvm::raw_string_ostream ss(ptx);
@@ -958,7 +959,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << " $" << (regCnt) << ","
<< " $" << (regCnt + 1) << ","
<< " p";
- if (!outputType.isInteger(32)) {
+ if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
// Don't add transpose parameters unless needed.
@@ -975,11 +976,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
&asmValues) {
- Value outValue = getResults() ? getResults() : getInouts();
- auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
- Type outputType = stype.getBody().front();
- bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
- getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+ bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
if (getResults())
asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
if (getInouts())
@@ -988,7 +985,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
mlir::NVVM::PTXRegisterMod::Read});
- if (!outputType.isInteger(32)) {
+ if (getTypeD() != WGMMATypes::s32) {
asmValues.push_back(
{makeConstantI32(rewriter,
getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bd..3ca970f412833f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -880,41 +880,41 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[UD:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S2]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64
// CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]] : i64
-// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], <m = 64, n = 128, k = 16>, D[%[[S4]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], %[[S4]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64
// CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]] : i64
// CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64
// CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]] : i64
-// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], <m = 64, n = 128, k = 16>, D[%[[S9]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], %[[S9]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64
// CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]] : i64
// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64
-// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], %[[S14]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S3:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64
-// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f1...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/78713
More information about the Mlir-commits
mailing list