[Mlir-commits] [mlir] 5c3150e - [MLIR][NVVM] Introduce WGMMA Types
Guray Ozen
llvmlistbot at llvm.org
Sat Aug 12 03:47:50 PDT 2023
Author: Guray Ozen
Date: 2023-08-12T12:47:45+02:00
New Revision: 5c3150e584b6449ac80434ed45a9bfb0188c444c
URL: https://github.com/llvm/llvm-project/commit/5c3150e584b6449ac80434ed45a9bfb0188c444c
DIFF: https://github.com/llvm/llvm-project/commit/5c3150e584b6449ac80434ed45a9bfb0188c444c.diff
LOG: [MLIR][NVVM] Introduce WGMMA Types
This work introduces `WGMMATypes` attributes for the `WgmmaMmaAsyncOp`. This op, having been recently added to MLIR, previously used `MMATypes`. However, there arises a disparity in supported types between `MmaOp` and `WgmmaMmaAsyncOp`. To address this discrepancy more effectively, a new set of attributes is introduced.
Furthermore, this patch refines and optimizes the verification mechanisms of the `WgmmaMmaAsyncOp` op.
It also adds support for f8 types, including `e4m3` and `e5m2`, within the `WgmmaMmaAsyncOp`.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D157695
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Conversion/NVVMToLLVM/invalid.mlir
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7ddf68e0a5b1a4..cfee215c2e3517 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1497,6 +1497,28 @@ def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out">
let assemblyFormat = "`<` $value `>`";
}
+/// Enum attribute of the different PTX element types used for WGMMA operands.
+def WGMMATypeF16 : I32EnumAttrCase<"f16", 0>;
+def WGMMATypeTF32 : I32EnumAttrCase<"tf32", 1>;
+def WGMMATypeU8 : I32EnumAttrCase<"u8", 2>;
+def WGMMATypeS8 : I32EnumAttrCase<"s8", 3>;
+def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
+def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
+def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
+def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
+def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
+ [WGMMATypeF16, WGMMATypeTF32,
+ WGMMATypeU8, WGMMATypeS8,
+ WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3,
+ WGMMATypeF8E5M2]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def WGMMATypesAttr : EnumAttr<NVVM_Dialect, WGMMATypes, "wgmma_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+
def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
PredOpTrait<"input struct and result struct must be the same type",
@@ -1508,15 +1530,14 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
I64:$descriptorA,
I64:$descriptorB,
NVVM_MMAShapeAttr:$shape,
- MMATypesAttr:$typeA,
- MMATypesAttr:$typeB,
+ WGMMATypesAttr:$typeA,
+ WGMMATypesAttr:$typeB,
WGMMAScaleOutAttr:$scaleD,
WGMMAScaleInAttr:$scaleA,
WGMMAScaleInAttr:$scaleB,
MMALayoutAttr:$layoutA,
MMALayoutAttr:$layoutB,
OptionalAttr<MMAIntOverflowAttr>:$satfinite
- // OptionalAttr<UnitAttr>:$satfinite
);
let assemblyFormat = [{
@@ -1536,44 +1557,50 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
Supported shapes:
```
- |-------------------|--------------------|----------------|---------------|
- | | f16 += f16 * f16 | s32 += s8 * s8 | |
- |f32 += tf32 * tf32 | f32 += f16 * f16 | s32 += s8 * u8 |s32 += b1 * b1 |
- | | f32 += bf16 * bf16 | s32 += u8 * u8 | |
- |-------------------|--------------------|----------------|---------------|
- | .m64n8k8 | .m64n8k16 | .m64n8k32 | .m64n8k256 |
- | .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 |
- | .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 |
- | .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 |
- | .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 |
- | .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 |
- | .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 |
- | .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 |
- | .m64n72k8 | .m64n72k16 | .m64n112k32 | .m64n112k256 |
- | .m64n80k8 | .m64n80k16 | .m64n128k32 | .m64n128k256 |
- | .m64n88k8 | .m64n88k16 | .m64n144k32 | .m64n144k256 |
- | .m64n96k8 | .m64n96k16 | .m64n160k32 | .m64n160k256 |
- | .m64n104k8 | .m64n104k16 | .m64n176k32 | .m64n176k256 |
- | .m64n112k8 | .m64n112k16 | .m64n192k32 | .m64n192k256 |
- | .m64n120k8 | .m64n120k16 | .m64n208k32 | .m64n208k256 |
- | .m64n128k8 | .m64n128k16 | .m64n224k32 | .m64n224k256 |
- | .m64n136k8 | .m64n136k16 | .m64n240k32 | .m64n240k256 |
- | .m64n144k8 | .m64n144k16 | .m64n256k32 | .m64n256k256 |
- | .m64n152k8 | .m64n152k16 | | |
- | .m64n160k8 | .m64n160k16 | | |
- | .m64n168k8 | .m64n168k16 | | |
- | .m64n176k8 | .m64n176k16 | | |
- | .m64n184k8 | .m64n184k16 | | |
- | .m64n192k8 | .m64n192k16 | | |
- | .m64n200k8 | .m64n200k16 | | |
- | .m64n208k8 | .m64n208k16 | | |
- | .m64n216k8 | .m64n216k16 | | |
- | .m64n224k8 | .m64n224k16 | | |
- | .m64n232k8 | .m64n232k16 | | |
- | .m64n240k8 | .m64n240k16 | | |
- | .m64n248k8 | .m64n248k16 | | |
- | .m64n256k8 | .m64n256k16 | | |
- |-------------------|--------------------|----------------|---------------|
+ |--------------|--------------|------------|--------------|---------------|
+ |              |              |            |              |f16+=e4m3*e4m3 |
+ |              |              |            |              |f16+=e5m2*e5m2 |
+ |f32+=tf32*tf32|f16+=f16 *f16 | s32+=s8*s8 |s32 += b1 * b1|f16+=e5m2*e4m3 |
+ |              |f32+=f16 *f16 | s32+=u8*u8 |              |f16+=e4m3*e5m2 |
+ |              |f32+=bf16*bf16| s32+=s8*u8 |              |f32+=e4m3*e4m3 |
+ |              |              | s32+=u8*s8 |              |f32+=e5m2*e5m2 |
+ |              |              |            |              |f32+=e5m2*e4m3 |
+ |              |              |            |              |f32+=e4m3*e5m2 |
+ |--------------|--------------|------------|--------------|---------------|
+ | .m64n8k8 | .m64n8k16 | .m64n8k32 | .m64n8k256 | .m64n8k32 |
+ | .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 | .m64n16k32 |
+ | .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 | .m64n24k32 |
+ | .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 | .m64n32k32 |
+ | .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 | .m64n40k32 |
+ | .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 | .m64n48k32 |
+ | .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 | .m64n56k32 |
+ | .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 | .m64n64k32 |
+ | .m64n72k8 | .m64n72k16 | .m64n112k32| .m64n112k256 | .m64n72k32 |
+ | .m64n80k8 | .m64n80k16 | .m64n128k32| .m64n128k256 | .m64n80k32 |
+ | .m64n88k8 | .m64n88k16 | .m64n144k32| .m64n144k256 | .m64n88k32 |
+ | .m64n96k8 | .m64n96k16 | .m64n160k32| .m64n160k256 | .m64n96k32 |
+ | .m64n104k8 | .m64n104k16 | .m64n176k32| .m64n176k256 | .m64n104k32 |
+ | .m64n112k8 | .m64n112k16 | .m64n192k32| .m64n192k256 | .m64n112k32 |
+ | .m64n120k8 | .m64n120k16 | .m64n208k32| .m64n208k256 | .m64n120k32 |
+ | .m64n128k8 | .m64n128k16 | .m64n224k32| .m64n224k256 | .m64n128k32 |
+ | .m64n136k8 | .m64n136k16 | .m64n240k32| .m64n240k256 | .m64n136k32 |
+ | .m64n144k8 | .m64n144k16 | .m64n256k32| .m64n256k256 | .m64n144k32 |
+ | .m64n152k8 | .m64n152k16 | | | .m64n152k32 |
+ | .m64n160k8 | .m64n160k16 | | | .m64n160k32 |
+ | .m64n168k8 | .m64n168k16 | | | .m64n168k32 |
+ | .m64n176k8 | .m64n176k16 | | | .m64n176k32 |
+ | .m64n184k8 | .m64n184k16 | | | .m64n184k32 |
+ | .m64n192k8 | .m64n192k16 | | | .m64n192k32 |
+ | .m64n200k8 | .m64n200k16 | | | .m64n200k32 |
+ | .m64n208k8 | .m64n208k16 | | | .m64n208k32 |
+ | .m64n216k8 | .m64n216k16 | | | .m64n216k32 |
+ | .m64n224k8 | .m64n224k16 | | | .m64n224k32 |
+ | .m64n232k8 | .m64n232k16 | | | .m64n232k32 |
+ | .m64n240k8 | .m64n240k16 | | | .m64n240k32 |
+ | .m64n248k8 | .m64n248k16 | | | .m64n248k32 |
+ | .m64n256k8 | .m64n256k16 | | | .m64n256k32 |
+ |--------------|--------------|------------|--------------|---------------|
```
See for more information:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 70c774233797bb..d6794881b442ee 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -37,6 +37,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
+#include <cassert>
#include <optional>
#include <string>
@@ -708,6 +709,81 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return success();
}
+FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
+ if (typeA == NVVM::WGMMATypes::tf32)
+ return 8;
+ if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
+ return 16;
+ if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
+ return 32;
+ if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
+ return 32;
+ if (typeA == NVVM::WGMMATypes::b1)
+ return 256;
+ return failure();
+}
+
+LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA,
+ NVVM::WGMMATypes typeB) {
+ switch (typeA) {
+ case NVVM::WGMMATypes::f16:
+ if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16)
+ return success();
+ break;
+ case NVVM::WGMMATypes::tf32:
+ if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32)
+ return success();
+ break;
+ case NVVM::WGMMATypes::u8:
+ case NVVM::WGMMATypes::s8:
+ if (typeD.isInteger(32) &&
+ (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
+ return success();
+ break;
+ case NVVM::WGMMATypes::b1:
+ if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1)
+ return success();
+ break;
+ case NVVM::WGMMATypes::bf16:
+ if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16)
+ return success();
+ break;
+ case NVVM::WGMMATypes::e4m3:
+ case NVVM::WGMMATypes::e5m2:
+ if ((typeD.isF32() || typeD.isF16()) &&
+ (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
+ return success();
+ break;
+ }
+ return failure();
+}
+
+LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
+ SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
+ 72, 80, 88, 96, 104, 112, 120, 128,
+ 136, 144, 152, 160, 168, 176, 184, 192,
+ 200, 208, 216, 224, 232, 240, 248, 256};
+ SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
+ 80, 96, 112, 128, 144, 160,
+ 176, 192, 208, 224, 240, 256};
+ switch (typeA) {
+ case mlir::NVVM::WGMMATypes::f16:
+ case mlir::NVVM::WGMMATypes::tf32:
+ case mlir::NVVM::WGMMATypes::bf16:
+ case mlir::NVVM::WGMMATypes::e4m3:
+ case mlir::NVVM::WGMMATypes::e5m2:
+ if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; }))
+ return success();
+ break;
+ case mlir::NVVM::WGMMATypes::u8:
+ case mlir::NVVM::WGMMATypes::s8:
+ case mlir::NVVM::WGMMATypes::b1:
+ if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; }))
+ return success();
+ }
+ return failure();
+}
+
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
Value outValue = getResults();
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
@@ -730,142 +806,49 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
}
+ mlir::NVVM::WGMMATypes typeA = getTypeA();
+ mlir::NVVM::WGMMATypes typeB = getTypeB();
+ if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) {
+ return emitOpError() << outputType
+ << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
+ << NVVM::stringifyWGMMATypes(typeB)
+ << ", it is not supported.";
+ }
+
// Check M
if (getShape().getM() != 64)
return emitOpError() << "shape 'm' must be 64";
// Check K
- mlir::NVVM::MMATypes typeA = getTypeA();
- mlir::NVVM::MMATypes typeB = getTypeB();
- switch (typeA) {
- case mlir::NVVM::MMATypes::bf16:
- case mlir::NVVM::MMATypes::f16:
- if (typeA != typeB) {
- return emitOpError() << "input types must be same but got "
- << NVVM::stringifyMMATypes(typeA) << " and "
- << NVVM::stringifyMMATypes(typeB);
- }
- if (getShape().getK() != 16) {
- return emitOpError() << "shape 'k' must be 16 "
- << "for input type "
- << NVVM::stringifyMMATypes(typeA);
- }
- break;
- case mlir::NVVM::MMATypes::tf32:
- if (typeA != typeB) {
- return emitOpError() << "input types must be same but got "
- << NVVM::stringifyMMATypes(typeA) << " and "
- << NVVM::stringifyMMATypes(typeB);
- }
- if (getShape().getK() != 8) {
- return emitOpError() << "shape 'k' must be 8 "
- << "for input type "
- << NVVM::stringifyMMATypes(typeA);
- }
- break;
- case mlir::NVVM::MMATypes::s8:
- case mlir::NVVM::MMATypes::u8:
- if (typeB != mlir::NVVM::MMATypes::s8 &&
- typeB != mlir::NVVM::MMATypes::u8) {
- return emitOpError() << "input type of rhs could be "
- << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::s8)
- << " or "
- << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::u8)
- << " same but got and "
- << NVVM::stringifyMMATypes(typeB);
- }
- if (getShape().getK() != 32) {
- emitOpError() << "shape 'k' must be 32 "
- << "for input type " << NVVM::stringifyMMATypes(typeA);
- }
- break;
- case mlir::NVVM::MMATypes::b1:
- if (typeA != typeB) {
- return emitOpError() << "input types must be same but got "
- << NVVM::stringifyMMATypes(typeA) << " and "
- << NVVM::stringifyMMATypes(typeB);
- }
- if (getShape().getK() != 256) {
- return emitOpError() << "shape 'k' must be 256 "
- << "for input type "
- << NVVM::stringifyMMATypes(typeA);
- }
- break;
- default:
- return emitOpError() << "Unsupported input type "
- << NVVM::stringifyMMATypes(typeA) << " and "
- << NVVM::stringifyMMATypes(typeB);
- }
+ FailureOr<int> allowedK = getAllowedSizeK(typeA);
+ if (failed(allowedK) || allowedK.value() != getShape().getK())
+ return emitOpError() << "shape 'k' must be " << allowedK.value()
+ << " for input type "
+ << NVVM::stringifyWGMMATypes(typeA);
// Check N
- SmallVector<int> allowedNShapesF16 = {8, 16, 24, 32, 40, 48, 56, 64,
- 72, 80, 88, 96, 104, 112, 120, 128,
- 136, 144, 152, 160, 168, 176, 184, 192,
- 200, 208, 216, 224, 232, 240, 248, 256};
- SmallVector<int> allowedNShapesU8S8B1 = {8, 16, 24, 32, 48, 64,
- 80, 96, 112, 128, 144, 160,
- 176, 192, 208, 224, 240, 256};
-
- bool validGEMMType = false;
- // f16 += f16 * f16
- if (outputType.isF16() && typeA == mlir::NVVM::MMATypes::f16) {
- if (!llvm::any_of(allowedNShapesF16,
- [&](int n) { return getShape().getN() == n; })) {
- return emitOpError() << "has input type "
- << NVVM::stringifyMMATypes(typeA) << " n is set to "
- << getShape().getN() << ", it is not supported.";
- }
- validGEMMType = true;
- }
- // f32 += tf32 * tf32| f32 += f16 * f16| f16 += bf16 * bf16
- if (outputType.isF32() && (typeA == mlir::NVVM::MMATypes::bf16 ||
- typeA == mlir::NVVM::MMATypes::tf32 ||
- typeA == mlir::NVVM::MMATypes::f16)) {
- if (!llvm::any_of(allowedNShapesF16,
- [&](int n) { return getShape().getN() == n; })) {
- return emitOpError() << "has input type "
- << NVVM::stringifyMMATypes(typeA) << " n is set to "
- << getShape().getN() << ", it is not supported.";
- }
- validGEMMType = true;
- }
- // s32 += s8 * s8 | s32 += s8 * u8 | s32 += u8 * u8 | s32 += b1 * b1
- if (outputType.isInteger(32) &&
- (typeA == mlir::NVVM::MMATypes::s8 || typeA == mlir::NVVM::MMATypes::u8 ||
- typeA == mlir::NVVM::MMATypes::b1)) {
- if (!llvm::any_of(allowedNShapesU8S8B1,
- [&](int n) { return getShape().getN() == n; })) {
- return emitOpError() << "has input type "
- << NVVM::stringifyMMATypes(typeA) << " n is set to "
- << getShape().getN() << ", it is not supported.";
- }
- validGEMMType = true;
+ if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
+ return emitOpError() << "has input type "
+ << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
+ << getShape().getN() << ", it is not supported.";
}
- if (!validGEMMType) {
- return emitOpError() << outputType
- << " += " << NVVM::stringifyMMATypes(typeA) << " * "
- << NVVM::stringifyMMATypes(typeB)
- << ", it is not supported.";
- }
-
- // Check transpose is needed from the given layouts. It is only
- // supported for bf16 or f16.
- if ((typeA != mlir::NVVM::MMATypes::f16 &&
- typeA != mlir::NVVM::MMATypes::bf16) &&
+ // Check transpose (only available for f16/bf16)
+ if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
+ typeA != mlir::NVVM::WGMMATypes::bf16) &&
(getLayoutA() == mlir::NVVM::MMALayout::col ||
getLayoutB() == mlir::NVVM::MMALayout::col)) {
return emitOpError()
<< "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
<< " and layout_b = " << stringifyMMALayout(getLayoutB())
- << " for input types " << stringifyMMATypes(typeA) << " and "
- << stringifyMMATypes(typeB)
+ << " for input types " << stringifyWGMMATypes(typeA) << " and "
+ << stringifyWGMMATypes(typeB)
<< " requires transpose. However, this is only supported for: "
<< stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
<< stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
}
- // Check number of result registers
+ // Check result registers
int expectedOutput;
if (outputType.isF32() || outputType.isInteger(32))
expectedOutput = getShape().getN() / 2;
@@ -876,7 +859,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
<< ", however output struct has " << outputSize
<< " elements";
}
- // Check satfinite is set. It is only for s32 accumulator
+ // Check satfinite (only available for s32 accumulator)
if (!outputType.isInteger(32) &&
getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
NVVM::MMAIntOverflow::satfinite) {
@@ -892,8 +875,8 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
- bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
- getTypeA() == mlir::NVVM::MMATypes::bf16;
+ bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
+ getTypeA() == mlir::NVVM::WGMMATypes::bf16;
Value outValue = getResults() ? getResults() : getInouts();
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
@@ -901,10 +884,13 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
std::string outputTypeName;
if (outputType.isF16())
outputTypeName = "f16";
- if (outputType.isF32())
+ else if (outputType.isF32())
outputTypeName = "f32";
else if (outputType.isInteger(32))
outputTypeName = "s32";
+ else
+ assert(false && "unsupported output type");
+
int expectedOutputRegisters;
if (outputType.isF32() || outputType.isInteger(32))
expectedOutputRegisters = getShape().getN() / 2;
@@ -921,7 +907,8 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
<< ", 0;\n"
"wgmma.mma_async.sync.aligned.m"
<< m << "n" << n << "k" << k << "." << outputTypeName << "."
- << stringifyMMATypes(getTypeA()) << "." << stringifyMMATypes(getTypeB());
+ << stringifyWGMMATypes(getTypeA()) << "."
+ << stringifyWGMMATypes(getTypeB());
if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
NVVM::MMAIntOverflow::satfinite)
ss << ".satfinite";
@@ -959,8 +946,8 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
Value outValue = getResults() ? getResults() : getInouts();
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
Type outputType = stype.getBody().front();
- bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
- getTypeA() == mlir::NVVM::MMATypes::bf16;
+ bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
+ getTypeA() == mlir::NVVM::WGMMATypes::bf16;
if (getResults())
asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
if (getInouts())
diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
index 7f8fa5267c343f..8b9df5fa598014 100644
--- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -62,7 +62,7 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
%result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
- // expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
+ // expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
%res = nvvm.wgmma.mma_async %descA, %descB,
#nvvm.shape<m = 64, n = 16, k = 3>,
D [%result, <zero>],
@@ -116,4 +116,19 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
: !llvm.struct<(i32, i32, i32, i32)>
-> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
return
-}
\ No newline at end of file
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // expected-error @+1 {{op 'f32' += bf16 * f16, it is not supported}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 16>,
+ D [%result, <zero>],
+ A [<bf16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ return
+}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index e1e2bcba40a256..3bcb41c18d4fb4 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -258,14 +258,71 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
%result1 = nvvm.wgmma.mma_async %descA, %descB,
#nvvm.shape<m = 64, n = 64, k = 8>,
D [%result, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
%result2 = nvvm.wgmma.mma_async %descA, %descB,
#nvvm.shape<m = 64, n = 64, k = 8>,
D [%result1, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ : !mat32f32 -> !mat32f32
+ return %result2 : !mat32f32
+}
+
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_e4m3_e4m3
+func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
+ // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ %result = llvm.mlir.undef : !mat32f32
+ %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [%result, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ : !mat32f32 -> !mat32f32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [%result1, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ : !mat32f32 -> !mat32f32
+ return %result2 : !mat32f32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_e5m2_e4m3
+func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
+ // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ %result = llvm.mlir.undef : !mat32f32
+ %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [%result, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ : !mat32f32 -> !mat32f32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [%result1, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
: !mat32f32 -> !mat32f32
return %result2 : !mat32f32
}
More information about the Mlir-commits
mailing list