[Mlir-commits] [mlir] 12c241b - [MLIR][NVVM] Explicit Data Type for Output in `wgmma.mma_async` (#78713)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 21 23:37:24 PST 2024
Author: Guray Ozen
Date: 2024-01-22T08:37:20+01:00
New Revision: 12c241b3654800ab708607dbc1998975c893fc14
URL: https://github.com/llvm/llvm-project/commit/12c241b3654800ab708607dbc1998975c893fc14
DIFF: https://github.com/llvm/llvm-project/commit/12c241b3654800ab708607dbc1998975c893fc14.diff
LOG: [MLIR][NVVM] Explicit Data Type for Output in `wgmma.mma_async` (#78713)
The current implementation of the `nvvm.wgmma.mma_async` Op deduces the data
type of the output matrix from the element type of the result struct, which
can be non-intuitive, especially in cases where types like `2xf16` are packed
into an `i32`.
This PR addresses this issue by extending the Op with an explicit
data type for the output matrix.
The modified Op now takes the accumulator as a regular operand and carries an
explicit data type for Matrix-D (e.g. `<f16>`); it looks as follows:
```
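%result = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
%res = nvvm.wgmma.mma_async
%descA, %descB, %result,
#nvvm.shape<m = 64, n = 32, k = 16>,
D [<f16>, #nvvm.wgmma_scale_out<zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>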
```
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Conversion/NVVMToLLVM/invalid.mlir
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
mlir/test/python/dialects/nvvm.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..b1bd3a95068076d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1833,11 +1833,14 @@ def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
+def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>;
+def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>;
+
def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
[WGMMATypeF16, WGMMATypeTF32,
WGMMATypeU8, WGMMATypeS8,
WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3,
- WGMMATypeF8E5M2]> {
+ WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -1859,6 +1862,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
NVVM_MMAShapeAttr:$shape,
WGMMATypesAttr:$typeA,
WGMMATypesAttr:$typeB,
+ WGMMATypesAttr:$typeD,
WGMMAScaleOutAttr:$scaleD,
WGMMAScaleInAttr:$scaleA,
WGMMAScaleInAttr:$scaleB,
@@ -1868,8 +1872,8 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
);
let assemblyFormat = [{
- $descriptorA `,` $descriptorB `,` $shape `,`
- `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
+ $descriptorA `,` $descriptorB `,` $inouts `,` $shape `,`
+ `D` `[` $typeD `,` $scaleD (`,` $satfinite^)? `]` `,`
`A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,`
`B` `[` $typeB `,` $scaleB `,` $layoutB `]`
attr-dict `:`
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index ab4dea9d5618d50..43d05b872a4fbc8 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1267,10 +1267,11 @@ struct NVGPUWarpgroupMmaOpLowering
}
/// Generates WGMMATypesAttr from MLIR Type
- NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
- auto getWgmmaType = [](Type elemType) {
+ NVVM::WGMMATypesAttr generateWgmmaType(Type type,
+ bool useF32 = false) const {
+ auto getWgmmaType = [=](Type elemType) {
if (elemType.isF32() || elemType.isTF32())
- return NVVM::WGMMATypes::tf32;
+ return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
if (elemType.isF16())
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
@@ -1285,6 +1286,8 @@ struct NVGPUWarpgroupMmaOpLowering
return NVVM::WGMMATypes::s8;
if (elemType.isUnsignedInteger(8))
return NVVM::WGMMATypes::u8;
+ if (elemType.isInteger(32))
+ return NVVM::WGMMATypes::s32;
llvm_unreachable("unsupported type");
};
return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
@@ -1397,6 +1400,9 @@ struct NVGPUWarpgroupMmaOpLowering
Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+ Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
+ NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
+
NVVM::MMAShapeAttr shape = generateWgmmaShape();
NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
@@ -1408,7 +1414,8 @@ struct NVGPUWarpgroupMmaOpLowering
return b.create<NVVM::WgmmaMmaAsyncOp>(
matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
- itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+ itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+ overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc02..a855e4b209ac5bd 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -755,37 +755,44 @@ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
return failure();
}
-LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA,
+LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
+ NVVM::WGMMATypes typeA,
NVVM::WGMMATypes typeB) {
switch (typeA) {
case NVVM::WGMMATypes::f16:
- if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16)
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+ typeB == NVVM::WGMMATypes::f16)
return success();
break;
case NVVM::WGMMATypes::tf32:
- if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32)
+ if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
return success();
break;
case NVVM::WGMMATypes::u8:
case NVVM::WGMMATypes::s8:
- if (typeD.isInteger(32) &&
+ if (typeD == NVVM::WGMMATypes::s32 &&
(typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
return success();
break;
case NVVM::WGMMATypes::b1:
- if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1)
+ if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
return success();
break;
case NVVM::WGMMATypes::bf16:
- if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16)
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+ typeB == NVVM::WGMMATypes::bf16)
return success();
break;
case NVVM::WGMMATypes::e4m3:
case NVVM::WGMMATypes::e5m2:
- if ((typeD.isF32() || typeD.isF16()) &&
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
(typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
return success();
break;
+ case WGMMATypes::f32:
+ case WGMMATypes::s32:
+ llvm_unreachable("unsupported input types");
+ break;
}
return failure();
}
@@ -799,19 +806,24 @@ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
80, 96, 112, 128, 144, 160,
176, 192, 208, 224, 240, 256};
switch (typeA) {
- case mlir::NVVM::WGMMATypes::f16:
- case mlir::NVVM::WGMMATypes::tf32:
- case mlir::NVVM::WGMMATypes::bf16:
- case mlir::NVVM::WGMMATypes::e4m3:
- case mlir::NVVM::WGMMATypes::e5m2:
+ case WGMMATypes::f16:
+ case WGMMATypes::tf32:
+ case WGMMATypes::bf16:
+ case WGMMATypes::e4m3:
+ case WGMMATypes::e5m2:
if (llvm::is_contained(allowedN, sizeN))
return success();
break;
- case mlir::NVVM::WGMMATypes::u8:
- case mlir::NVVM::WGMMATypes::s8:
- case mlir::NVVM::WGMMATypes::b1:
+ case WGMMATypes::u8:
+ case WGMMATypes::s8:
+ case WGMMATypes::b1:
if (llvm::is_contained(allowedNshort, sizeN))
return success();
+ break;
+ case WGMMATypes::f32:
+ case WGMMATypes::s32:
+ llvm_unreachable("unsupported input types");
+ break;
}
return failure();
}
@@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
if (!stype)
return emitOpError() << "expected results to be struct";
- Type outputType = stype.getBody().front();
int outputSize = stype.getBody().size();
+ WGMMATypes typeD = getTypeD();
+ WGMMATypes typeA = getTypeA();
+ WGMMATypes typeB = getTypeB();
+
for (Type t : stype.getBody()) {
- if (t != outputType)
+ if (t != stype.getBody().front())
return emitOpError()
<< "all elements in struct must be same type but there is " << t;
}
- if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
+ if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
+ typeD != WGMMATypes::s32) {
return emitOpError() << "does not support the given output type "
- << outputType;
+ << NVVM::stringifyWGMMATypes(typeD);
}
- if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
- getScaleB() == NVVM::WGMMAScaleIn::neg)) {
+ if (typeD == WGMMATypes::s32 &&
+ (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
}
- mlir::NVVM::WGMMATypes typeA = getTypeA();
- mlir::NVVM::WGMMATypes typeB = getTypeB();
- if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) {
- return emitOpError() << outputType
+ if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
+ return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
<< " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
<< NVVM::stringifyWGMMATypes(typeB)
<< ", it is not supported.";
@@ -866,8 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
}
// Check transpose (only available for f16/bf16)
- if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
- typeA != mlir::NVVM::WGMMATypes::bf16) &&
+ if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
(getLayoutA() == mlir::NVVM::MMALayout::col ||
getLayoutB() == mlir::NVVM::MMALayout::col)) {
return emitOpError()
@@ -876,29 +889,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
<< " for input types " << stringifyWGMMATypes(typeA) << " and "
<< stringifyWGMMATypes(typeB)
<< " requires transpose. However, this is only supported for: "
- << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
- << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
+ << stringifyMMATypes(MMATypes::f16) << " and "
+ << stringifyMMATypes(MMATypes::bf16);
}
// Check result registers
- int expectedOutput;
- if (outputType.isF32() || outputType.isInteger(32))
+ int expectedOutput = 0;
+ if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
expectedOutput = getShape().getN() / 2;
- if (outputType.isF16())
+ if (typeD == WGMMATypes::f16)
expectedOutput = getShape().getN() / 4;
if (outputSize != expectedOutput) {
return emitOpError() << "results " << expectedOutput
<< ", however output struct has " << outputSize
<< " elements";
}
- // Check satfinite (only availalbe for s32 accumulator)
- if (!outputType.isInteger(32) &&
+ // Check satfinite (only available for s32 accumulator)
+ if (typeD != WGMMATypes::s32 &&
getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
NVVM::MMAIntOverflow::satfinite) {
return emitOpError()
<< " `satfinite` can be only used with s32 accumulator, however "
"the current accumulator is "
- << outputType;
+ << NVVM::stringifyWGMMATypes(typeD);
}
return success();
@@ -907,27 +920,15 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
- bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
- getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+ bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
- Value outValue = getResults() ? getResults() : getInouts();
- auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
- Type outputType = stype.getBody().front();
- std::string outputTypeName;
- if (outputType.isF16())
- outputTypeName = "f16";
- else if (outputType.isF32())
- outputTypeName = "f32";
- else if (outputType.isInteger(32))
- outputTypeName = "s32";
- else
- assert(false && "unsupported output type");
+ StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
- int expectedOutputRegisters;
- if (outputType.isF32() || outputType.isInteger(32))
- expectedOutputRegisters = getShape().getN() / 2;
- if (outputType.isF16())
+ int expectedOutputRegisters = 0;
+ if (getTypeD() == WGMMATypes::f16)
expectedOutputRegisters = getShape().getN() / 4;
+ else
+ expectedOutputRegisters = getShape().getN() / 2;
std::string ptx;
llvm::raw_string_ostream ss(ptx);
@@ -958,7 +959,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << " $" << (regCnt) << ","
<< " $" << (regCnt + 1) << ","
<< " p";
- if (!outputType.isInteger(32)) {
+ if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
// Don't add transpose parameters unless needed.
@@ -975,11 +976,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
&asmValues) {
- Value outValue = getResults() ? getResults() : getInouts();
- auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
- Type outputType = stype.getBody().front();
- bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
- getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+ bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
if (getResults())
asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
if (getInouts())
@@ -988,7 +985,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
mlir::NVVM::PTXRegisterMod::Read});
- if (!outputType.isInteger(32)) {
+ if (getTypeD() != WGMMATypes::s32) {
asmValues.push_back(
{makeConstantI32(rewriter,
getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b495363e228d8f4..b25dd76bf12037b 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -880,41 +880,41 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[UD:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S2]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64
// CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]] : i64
-// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], <m = 64, n = 128, k = 16>, D[%[[S4]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], %[[S4]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64
// CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]] : i64
// CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64
// CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]] : i64
-// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], <m = 64, n = 128, k = 16>, D[%[[S9]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], %[[S9]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64
// CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]] : i64
// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64
-// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], %[[S14]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S3:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64
-// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], %[[S3]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64
// CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]] : i64
// CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64
// CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]] : i64
-// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], <m = 64, n = 128, k = 16>, D[%[[S23]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], %[[S23]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64
// CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]] : i64
// CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64
// CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]] : i64
-// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], <m = 64, n = 128, k = 16>, D[%[[S28]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], %[[S28]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64
// CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]] : i64
// CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64
// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64
-// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], %[[S33]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S40:.+]] = llvm.insertvalue %[[S19]], %[[UD]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S41:.+]] = llvm.insertvalue %[[S38]], %[[S40]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: nvvm.wgmma.commit.group.sync.aligned
@@ -1299,7 +1299,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, <m = 64, n = 128, k = 16>, D[%[[S138]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.mma_async
// CHECK: nvvm.wgmma.mma_async
// CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async
diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
index 34c8de9f7ed8c6d..9ebe3a009adf258 100644
--- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -4,9 +4,9 @@
func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
%result = llvm.mlir.undef : !mat64f32
// expected-error @+1 {{'nvvm.wgmma.mma_async' op results 64, however output struct has 7 elements}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 128, k = 16>,
- D [%result, <zero>],
+ D [<f32>, <zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !mat64f32 -> !mat64f32
@@ -17,10 +17,10 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
- // expected-error @+1 {{`satfinite` can be only used with s32 accumulator, however the current accumulator is 'f32'}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ // expected-error @+1 {{`satfinite` can be only used with s32 accumulator, however the current accumulator is f32}}
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 16, k = 16>,
- D [%result, <zero>, <satfinite>],
+ D [<f32>, <zero>, <satfinite>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -33,9 +33,9 @@ func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) {
func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// expected-error @+1 {{shape 'm' must be 64}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 32, n = 16, k = 16>,
- D [%result, <zero>],
+ D [<f32>, <zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -48,9 +48,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)>
// expected-error @+1 {{op all elements in struct must be same type but there is 'i32'}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 16, k = 16>,
- D [%result, <zero>],
+ D [<f32>, <zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)>
@@ -63,9 +63,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 16, k = 3>,
- D [%result, <zero>],
+ D [<f32>, <zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -78,9 +78,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
func.func @wgmma_transpose(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// expected-error @+1 {{op given layouts layout_a = col and layout_b = col for input types tf32 and tf32 requires transpose. However, this is only supported for: f16 and bf16}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 16, k = 8>,
- D [%result, <zero>],
+ D [<f32>, <zero>],
A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -92,10 +92,10 @@ func.func @wgmma_transpose(%descA : i64, %descB : i64) {
func.func @wgmma_transpose(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16)>
- // expected-error @+1 {{'nvvm.wgmma.mma_async' op 'f16' += tf32 * tf32, it is not supported.}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ // expected-error @+1 {{'nvvm.wgmma.mma_async' op f16 += tf32 * tf32, it is not supported.}}
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 16, k = 8>,
- D [%result, <zero>],
+ D [<f16>, <zero>],
A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
:!llvm.struct<(f16, f16, f16, f16)>
@@ -108,9 +108,9 @@ func.func @wgmma_transpose(%descA : i64, %descB : i64) {
func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
// expected-error @+1 {{input struct and result struct must be the same type}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 8, k = 16>,
- D [%result, <zero>],
+ D [<f16>, <zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(i32, i32, i32, i32)>
@@ -122,10 +122,10 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
- // expected-error @+1 {{op 'f32' += bf16 * f16, it is not supported}}
- %res = nvvm.wgmma.mma_async %descA, %descB,
+ // expected-error @+1 {{op f32 += bf16 * f16, it is not supported}}
+ %res = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 8, k = 16>,
- D [%result, <zero>],
+ D [<f32>, <zero>],
A [<bf16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218a..9c7c27c49eb11d6 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -329,9 +329,9 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
// CHECK-SAME: %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
%result = llvm.mlir.undef : !mat64f32
%result1 = nvvm.wgmma.mma_async
- %descA, %descB,
+ %descA, %descB, %result,
#nvvm.shape<m = 64, n = 32, k = 16>,
- D [%result, #nvvm.wgmma_scale_out<zero>],
+ D [<f32>, #nvvm.wgmma_scale_out<zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
:!mat64f32 -> !mat64f32
@@ -339,9 +339,9 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
%descAnext = arith.addi %descA, %c2 : i64
%descBnext = arith.addi %descB, %c2 : i64
%result2 = nvvm.wgmma.mma_async
- %descAnext, %descBnext,
+ %descAnext, %descBnext, %result1,
#nvvm.shape<m = 64, n = 32, k = 16>,
- D [%result1, #nvvm.wgmma_scale_out<zero>],
+ D [<f32>, #nvvm.wgmma_scale_out<zero>],
A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
: !mat64f32 -> !mat64f32
@@ -393,21 +393,21 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n"
// CHECK-SAME: %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
- %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 8, k = 32>,
- D [%result, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
: !mat16i32 -> !mat16i32
- %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 8, k = 32>,
- D [%result1, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
: !mat16i32 -> !mat16i32
- %result3 = nvvm.wgmma.mma_async %descA, %descB,
+ %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
#nvvm.shape<m = 64, n = 8, k = 32>,
- D [%result2, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
: !mat16i32 -> !mat16i32
@@ -454,21 +454,21 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
// CHECK-SAME:}\0A",
// CHECK-SAME:"=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
%result = llvm.mlir.undef : !mat16i32
- %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 8, k = 32>,
- D [%result, #nvvm.wgmma_scale_out<one>],
+ D [<s32>, #nvvm.wgmma_scale_out<one>],
A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
: !mat16i32 -> !mat16i32
- %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 8, k = 32>,
- D [%result1, #nvvm.wgmma_scale_out<one>],
+ D [<s32>, #nvvm.wgmma_scale_out<one>],
A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
: !mat16i32 -> !mat16i32
- %result3 = nvvm.wgmma.mma_async %descA, %descB,
+ %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
#nvvm.shape<m = 64, n = 8, k = 32>,
- D [%result2, #nvvm.wgmma_scale_out<one>],
+ D [<s32>, #nvvm.wgmma_scale_out<one>],
A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
: !mat16i32 -> !mat16i32
@@ -496,15 +496,15 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
// CHECK-SAME: setp.ne.b32 p, $66, 0;
// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
%result = llvm.mlir.undef : !mat32f32
- %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 64, k = 8>,
- D [%result, #nvvm.wgmma_scale_out<one>],
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
- %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 64, k = 8>,
- D [%result1, #nvvm.wgmma_scale_out<one>],
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
@@ -529,15 +529,15 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
// CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
%result = llvm.mlir.undef : !mat32f32
- %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 64, k = 32>,
- D [%result, #nvvm.wgmma_scale_out<one>],
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
- %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 64, k = 32>,
- D [%result1, #nvvm.wgmma_scale_out<one>],
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
@@ -561,15 +561,15 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
// CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
%result = llvm.mlir.undef : !mat32f32
- %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
#nvvm.shape<m = 64, n = 64, k = 32>,
- D [%result, #nvvm.wgmma_scale_out<one>],
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
- %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 64, k = 32>,
- D [%result1, #nvvm.wgmma_scale_out<one>],
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index 36aaaea79b18667..0eef97d95479a7c 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -32,7 +32,7 @@ def wgmma_f32_f16_f16(desc_a, desc_b):
nvvm.CpAsyncWaitGroupOp(5)
# CHECK: %0 = llvm.mlir.undef : [[MAT_T:.*]]
result = llvm.UndefOp(mat64f32_t)
- # CHECK: %1 = nvvm.wgmma.mma_async %arg0, %arg1, <m = 64, n = 32, k = 16>, D[%0, <zero>], A[<f16>, <neg>, <col>], B[<f16>, <neg>, <col>] : [[MAT_T]] -> [[MAT_T]]
+ # CHECK: %1 = nvvm.wgmma.mma_async %arg0, %arg1, %0, <m = 64, n = 32, k = 16>, D[<f32>, <zero>], A[<f16>, <neg>, <col>], B[<f16>, <neg>, <col>] : [[MAT_T]] -> [[MAT_T]]
result1 = nvvm.WgmmaMmaAsyncOp(
results_=mat64f32_t,
inouts=result,
@@ -41,6 +41,7 @@ def wgmma_f32_f16_f16(desc_a, desc_b):
shape=shape_attr,
typeA=nvvm.WGMMATypes.f16,
typeB=nvvm.WGMMATypes.f16,
+ typeD=nvvm.WGMMATypes.f32,
scaleD=nvvm.WGMMAScaleOut.zero,
scaleA=nvvm.WGMMAScaleIn.neg,
scaleB=nvvm.WGMMAScaleIn.neg,
More information about the Mlir-commits mailing list