[Mlir-commits] [mlir] 18e161f - [MLIR][NVVM] Introduction of the `wgmma.mma_async` Op
Guray Ozen
llvmlistbot at llvm.org
Wed Aug 9 14:08:06 PDT 2023
Author: Guray Ozen
Date: 2023-08-09T23:08:00+02:00
New Revision: 18e161f9e15b036faf48bfd8813d9330e06e2ee3
URL: https://github.com/llvm/llvm-project/commit/18e161f9e15b036faf48bfd8813d9330e06e2ee3
DIFF: https://github.com/llvm/llvm-project/commit/18e161f9e15b036faf48bfd8813d9330e06e2ee3.diff
LOG: [MLIR][NVVM] Introduction of the `wgmma.mma_async` Op
This work introduces the `wgmma.mma_async` Op along with its PTX generation via `BasicPtxBuilderOpInterface`. The Op executes the matrix multiply-and-accumulate operation across a warpgroup (128 threads). Note that this operation requires a device with the sm_90a capability.
The matrix multiply-and-accumulate operation can take one of the following forms. In both cases, matrix D is referred to as the accumulator:
D = A * B + D : Result is added to the accumulator matrix D.
D = A * B : The input from the accumulator matrix D is not utilized.
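For illustration, here is a minimal sketch of the new Op in use, mirroring the syntax exercised by the tests added below (the s32 += s8 * s8 case with the m64n8k32 shape; the function and value names are placeholders):
```mlir
!acc = !llvm.struct<(i32, i32, i32, i32)>
func.func @wgmma_example(%descA : i64, %descB : i64) -> !acc {
  // Accumulator registers for the warpgroup-level result.
  %init = llvm.mlir.undef : !acc
  // D = A * B + D : scale_out<one> keeps the accumulator input enabled.
  %d = nvvm.wgmma.mma_async %descA, %descB,
      #nvvm.shape<m = 64, n = 8, k = 32>,
      D [%init, #nvvm.wgmma_scale_out<one>, <satfinite>],
      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
      : !acc -> !acc
  return %d : !acc
}
```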
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D157370
Added:
mlir/test/Conversion/NVVMToLLVM/invalid.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 490a0db9baa028..4845a226ec9c17 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -168,7 +168,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
/*desc=*/[{Generate constant value.}],
/*retType=*/"::mlir::Value",
/*methodName=*/"makeConstantI32",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "unsigned" : $val),
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
/*methodBody=*/"",
/*defaultImpl=*/ [{
mlir::Operation* op = $_op;
@@ -1473,6 +1473,121 @@ def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
}];
}
+/// Enum attribute type for negating the input operands
+def WGMMAScaleInNeg : I32EnumAttrCase<"neg", -1>;
+def WGMMAScaleInOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleIn : I32EnumAttr<"WGMMAScaleIn", "WGMMA input scale",
+                  [WGMMAScaleInOne, WGMMAScaleInNeg]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleInAttr : EnumAttr<NVVM_Dialect, WGMMAScaleIn, "wgmma_scale_in"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+/// Enum attribute type for scaling the accumulator (output) operand
+def WGMMAScaleOutZero : I32EnumAttrCase<"zero", 0>;
+def WGMMAScaleOutOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleOut : I32EnumAttr<"WGMMAScaleOut", "WGMMA output scale",
+                  [WGMMAScaleOutZero, WGMMAScaleOutOne]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_WgmmaMmaSyncOp : NVVM_Op<"wgmma.mma_async",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ PredOpTrait<"input struct and result struct must be the same type",
+ TCresIsSameAsOpBase<0, 0>>,]>
+{
+ let results = (outs LLVM_AnyAggregate:$results);
+ let arguments = (ins
+ LLVM_AnyAggregate:$inouts,
+ I64:$descriptorA,
+ I64:$descriptorB,
+ NVVM_MMAShapeAttr:$shape,
+ MMATypesAttr:$typeA,
+ MMATypesAttr:$typeB,
+ WGMMAScaleOutAttr:$scaleD,
+ WGMMAScaleInAttr:$scaleA,
+ WGMMAScaleInAttr:$scaleB,
+ MMALayoutAttr:$layoutA,
+ MMALayoutAttr:$layoutB,
+ OptionalAttr<MMAIntOverflowAttr>:$satfinite
+ // OptionalAttr<UnitAttr>:$satfinite
+ );
+
+ let assemblyFormat = [{
+ $descriptorA `,` $descriptorB `,` $shape `,`
+ `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
+ `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,`
+ `B` `[` $typeB `,` $scaleB `,` $layoutB `]`
+ attr-dict `:`
+ type($inouts) `->` type($results)
+ }];
+
+ let description = [{
+    The warpgroup (128 threads) level matrix multiply-and-accumulate operation
+    has either of the following forms, where matrix D is called the accumulator:
+      D = A * B + D
+      D = A * B, where the input from the accumulator D is disabled.
+
+ Supported shapes:
+ ```
+ |-------------------|--------------------|----------------|---------------|
+ | | f16 += f16 * f16 | s32 += s8 * s8 | |
+ |f32 += tf32 * tf32 | f32 += f16 * f16 | s32 += s8 * u8 |s32 += b1 * b1 |
+ | | f32 += bf16 * bf16 | s32 += u8 * u8 | |
+ |-------------------|--------------------|----------------|---------------|
+ | .m64n8k8 | .m64n8k16 | .m64n8k32 | .m64n8k256 |
+ | .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 |
+ | .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 |
+ | .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 |
+ | .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 |
+ | .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 |
+ | .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 |
+ | .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 |
+ | .m64n72k8 | .m64n72k16 | .m64n112k32 | .m64n112k256 |
+ | .m64n80k8 | .m64n80k16 | .m64n128k32 | .m64n128k256 |
+ | .m64n88k8 | .m64n88k16 | .m64n144k32 | .m64n144k256 |
+ | .m64n96k8 | .m64n96k16 | .m64n160k32 | .m64n160k256 |
+ | .m64n104k8 | .m64n104k16 | .m64n176k32 | .m64n176k256 |
+ | .m64n112k8 | .m64n112k16 | .m64n192k32 | .m64n192k256 |
+ | .m64n120k8 | .m64n120k16 | .m64n208k32 | .m64n208k256 |
+ | .m64n128k8 | .m64n128k16 | .m64n224k32 | .m64n224k256 |
+ | .m64n136k8 | .m64n136k16 | .m64n240k32 | .m64n240k256 |
+ | .m64n144k8 | .m64n144k16 | .m64n256k32 | .m64n256k256 |
+ | .m64n152k8 | .m64n152k16 | | |
+ | .m64n160k8 | .m64n160k16 | | |
+ | .m64n168k8 | .m64n168k16 | | |
+ | .m64n176k8 | .m64n176k16 | | |
+ | .m64n184k8 | .m64n184k16 | | |
+ | .m64n192k8 | .m64n192k16 | | |
+ | .m64n200k8 | .m64n200k16 | | |
+ | .m64n208k8 | .m64n208k16 | | |
+ | .m64n216k8 | .m64n216k16 | | |
+ | .m64n224k8 | .m64n224k16 | | |
+ | .m64n232k8 | .m64n232k16 | | |
+ | .m64n240k8 | .m64n240k16 | | |
+ | .m64n248k8 | .m64n248k16 | | |
+ | .m64n256k8 | .m64n256k16 | | |
+ |-------------------|--------------------|----------------|---------------|
+ ```
+
+    For more information, see:
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions
+ }];
+
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ void getAsmValues(RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 52abbe998872ab..2d7a441e950045 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -53,7 +53,7 @@ using namespace NVVM;
namespace {
class PtxBuilder {
- Operation *op;
+ NVVM::BasicPtxBuilderInterface op;
PatternRewriter &rewriter;
std::string asmStr;
SmallVector<Value> asmVals;
@@ -62,30 +62,35 @@ class PtxBuilder {
bool hasResult = false;
// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
- char getRegisterType(Value v) {
- if (v.getDefiningOp<LLVM::ConstantOp>())
- return 'n';
- if (v.getType().isInteger(16))
+ char getRegisterType(Type type) {
+ if (type.isInteger(16))
return 'h';
- if (v.getType().isInteger(32))
+ if (type.isInteger(32))
return 'r';
- if (v.getType().isInteger(64))
+ if (type.isInteger(64))
return 'l';
- if (v.getType().isF32())
+ if (type.isF32())
return 'f';
- if (v.getType().isF64())
+ if (type.isF64())
return 'd';
- if (auto ptr = v.getType().dyn_cast<LLVM::LLVMPointerType>()) {
+ if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
// The shared address space is addressed with 32-bit pointers.
if (ptr.getAddressSpace() == NVVM::kSharedMemorySpace) {
return 'r';
}
return 'l';
}
- assert(false && "Register type is not handled yet");
+    op->emitError() << "Register type could not be deduced from MLIR type: "
+                    << type;
return ' ';
}
+ char getRegisterType(Value v) {
+ if (v.getDefiningOp<LLVM::ConstantOp>())
+ return 'n';
+ return getRegisterType(v.getType());
+ }
+
public:
PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
bool sideEffects = false)
@@ -93,26 +98,60 @@ class PtxBuilder {
sideEffects(sideEffects) {}
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
- llvm::raw_string_ostream ss(asmConstraints);
- if (itype == PTXRegisterMod::Read) {
- asmVals.push_back(v);
- } else if (itype == PTXRegisterMod::ReadWrite) {
- asmVals.push_back(v);
- ss << "+";
- hasResult = true;
- } else if (itype == PTXRegisterMod::Write) {
- ss << "=";
+ LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+ auto getModifier = [&]() -> const char * {
+ if (itype == PTXRegisterMod::ReadWrite) {
+      assert(false && "Read-Write modifier is not supported. Try setting the "
+                      "same value as Write and Read separately.");
+ return "+";
+ }
+ if (itype == PTXRegisterMod::Write) {
+ return "=";
+ }
+ return "";
+ };
+ auto addValue = [&](Value v) {
+ if (itype == PTXRegisterMod::Read) {
+ asmVals.push_back(v);
+ return;
+ }
+ if (itype == PTXRegisterMod::ReadWrite)
+ asmVals.push_back(v);
hasResult = true;
+ };
+
+ llvm::raw_string_ostream ss(asmConstraints);
+ // Handle Structs
+ if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
+ if (itype == PTXRegisterMod::Write) {
+ addValue(v);
+ }
+ for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
+ if (itype != PTXRegisterMod::Write) {
+ Value extractValue =
+ rewriter.create<LLVM::ExtractValueOp>(op->getLoc(), v, idx);
+ addValue(extractValue);
+ }
+ if (itype == PTXRegisterMod::ReadWrite) {
+ ss << idx << ",";
+ } else {
+ ss << getModifier() << getRegisterType(t) << ",";
+ }
+ ss.flush();
+ }
+ return;
}
- ss << getRegisterType(v) << ",";
+ // Handle Scalars
+ addValue(v);
+ ss << getModifier() << getRegisterType(v) << ",";
ss.flush();
}
LLVM::InlineAsmOp build() {
auto asmDialectAttr =
LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
- Type resultType = hasResult ? op->getResult(0).getType()
- : LLVM::LLVMVoidType::get(op->getContext());
+
+ auto resultTypes = op->getResultTypes();
// Remove the last comma from the constraints string.
if (!asmConstraints.empty() &&
@@ -123,7 +162,8 @@ class PtxBuilder {
std::replace(asmStr.begin(), asmStr.end(), '%', '$');
return rewriter.create<LLVM::InlineAsmOp>(
- op->getLoc(), resultType,
+ op->getLoc(),
+ /*result types=*/resultTypes,
/*operands=*/asmVals,
/*asm_string=*/llvm::StringRef(asmStr),
/*constraints=*/asmConstraints.data(),
@@ -159,6 +199,7 @@ struct PtxLowering
}
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
+ LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
PtxBuilder generator(op, rewriter, op.getPtx(), op.hasSideEffect());
op.getAsmValues(rewriter, asmValues);
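To make the struct handling concrete, here is the operand-constraint layout the builder produces for the wgmma case checked in the tests below: a four-element i32 accumulator struct, two i64 descriptors, and one immediate scale operand yield
```
"=r,=r,=r,=r,0,1,2,3,l,l,n"
```
where `=r` marks the four written i32 outputs, the tied constraints `0,1,2,3` feed the read-write accumulator elements back into those same registers, `l` covers the two i64 matrix descriptors, and `n` the constant operand.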
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f086af6b74b1b0..25f837547b46f5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -26,7 +26,9 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
@@ -34,6 +36,7 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <string>
@@ -705,6 +708,287 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return success();
}
+LogicalResult NVVM::WgmmaMmaSyncOp::verify() {
+ Value outValue = getResults();
+ auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+ if (!stype)
+ return emitOpError() << "expected results to be struct";
+ Type outputType = stype.getBody().front();
+ int outputSize = stype.getBody().size();
+ for (Type t : stype.getBody()) {
+ if (t != outputType)
+      return emitOpError()
+             << "all elements in the struct must have the same type, but found "
+             << t;
+ }
+
+ if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
+ return emitOpError() << "does not support the given output type "
+ << outputType;
+ }
+ if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
+ getScaleB() == NVVM::WGMMAScaleIn::neg)) {
+ return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+ }
+
+ // Check M
+ if (getShape().getM() != 64)
+ return emitOpError() << "shape 'm' must be 64";
+
+ // Check K
+ mlir::NVVM::MMATypes typeA = getTypeA();
+ mlir::NVVM::MMATypes typeB = getTypeB();
+ switch (typeA) {
+ case mlir::NVVM::MMATypes::bf16:
+ case mlir::NVVM::MMATypes::f16:
+ if (typeA != typeB) {
+ return emitOpError() << "input types must be same but got "
+ << NVVM::stringifyMMATypes(typeA) << " and "
+ << NVVM::stringifyMMATypes(typeB);
+ }
+ if (getShape().getK() != 16) {
+ return emitOpError() << "shape 'k' must be 16 "
+ << "for input type "
+ << NVVM::stringifyMMATypes(typeA);
+ }
+ break;
+ case mlir::NVVM::MMATypes::tf32:
+ if (typeA != typeB) {
+ return emitOpError() << "input types must be same but got "
+ << NVVM::stringifyMMATypes(typeA) << " and "
+ << NVVM::stringifyMMATypes(typeB);
+ }
+ if (getShape().getK() != 8) {
+ return emitOpError() << "shape 'k' must be 8 "
+ << "for input type "
+ << NVVM::stringifyMMATypes(typeA);
+ }
+ break;
+ case mlir::NVVM::MMATypes::s8:
+ case mlir::NVVM::MMATypes::u8:
+ if (typeB != mlir::NVVM::MMATypes::s8 &&
+ typeB != mlir::NVVM::MMATypes::u8) {
+ return emitOpError() << "input type of rhs could be "
+ << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::s8)
+ << " or "
+ << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::u8)
+ << " same but got and "
+ << NVVM::stringifyMMATypes(typeB);
+ }
+ if (getShape().getK() != 32) {
+      return emitOpError() << "shape 'k' must be 32 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+ }
+ break;
+ case mlir::NVVM::MMATypes::b1:
+ if (typeA != typeB) {
+ return emitOpError() << "input types must be same but got "
+ << NVVM::stringifyMMATypes(typeA) << " and "
+ << NVVM::stringifyMMATypes(typeB);
+ }
+ if (getShape().getK() != 256) {
+ return emitOpError() << "shape 'k' must be 256 "
+ << "for input type "
+ << NVVM::stringifyMMATypes(typeA);
+ }
+ break;
+ default:
+ return emitOpError() << "Unsupported input type "
+ << NVVM::stringifyMMATypes(typeA) << " and "
+ << NVVM::stringifyMMATypes(typeB);
+ }
+
+ // Check N
+ SmallVector<int> allowedNShapesF16 = {8, 16, 24, 32, 40, 48, 56, 64,
+ 72, 80, 88, 96, 104, 112, 120, 128,
+ 136, 144, 152, 160, 168, 176, 184, 192,
+ 200, 208, 216, 224, 232, 240, 248, 256};
+ SmallVector<int> allowedNShapesU8S8B1 = {8, 16, 24, 32, 48, 64,
+ 80, 96, 112, 128, 144, 160,
+ 176, 192, 208, 224, 240, 256};
+
+ bool validGEMMType = false;
+ // f16 += f16 * f16
+ if (outputType.isF16() && typeA == mlir::NVVM::MMATypes::f16) {
+ if (!llvm::any_of(allowedNShapesF16,
+ [&](int n) { return getShape().getN() == n; })) {
+ return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA)
+                           << ", but shape 'n' set to " << getShape().getN()
+                           << " is not supported";
+ }
+ validGEMMType = true;
+ }
+  // f32 += tf32 * tf32 | f32 += f16 * f16 | f32 += bf16 * bf16
+ if (outputType.isF32() && (typeA == mlir::NVVM::MMATypes::bf16 ||
+ typeA == mlir::NVVM::MMATypes::tf32 ||
+ typeA == mlir::NVVM::MMATypes::f16)) {
+ if (!llvm::any_of(allowedNShapesF16,
+ [&](int n) { return getShape().getN() == n; })) {
+ return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA)
+                           << ", but shape 'n' set to " << getShape().getN()
+                           << " is not supported";
+ }
+ validGEMMType = true;
+ }
+ // s32 += s8 * s8 | s32 += s8 * u8 | s32 += u8 * u8 | s32 += b1 * b1
+ if (outputType.isInteger(32) &&
+ (typeA == mlir::NVVM::MMATypes::s8 || typeA == mlir::NVVM::MMATypes::u8 ||
+ typeA == mlir::NVVM::MMATypes::b1)) {
+ if (!llvm::any_of(allowedNShapesU8S8B1,
+ [&](int n) { return getShape().getN() == n; })) {
+ return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA)
+                           << ", but shape 'n' set to " << getShape().getN()
+                           << " is not supported";
+ }
+ validGEMMType = true;
+ }
+
+ if (!validGEMMType) {
+ return emitOpError() << outputType
+ << " += " << NVVM::stringifyMMATypes(typeA) << " * "
+ << NVVM::stringifyMMATypes(typeB)
+ << ", it is not supported.";
+ }
+
+  // Check whether a transpose is required by the given layouts. It is only
+  // supported for bf16 or f16.
+ if ((typeA != mlir::NVVM::MMATypes::f16 &&
+ typeA != mlir::NVVM::MMATypes::bf16) &&
+ (getLayoutA() == mlir::NVVM::MMALayout::col ||
+ getLayoutB() == mlir::NVVM::MMALayout::col)) {
+ return emitOpError()
+ << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
+ << " and layout_b = " << stringifyMMALayout(getLayoutB())
+ << " for input types " << stringifyMMATypes(typeA) << " and "
+ << stringifyMMATypes(typeB)
+ << " requires transpose. However, this is only supported for: "
+ << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
+ << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
+ }
+
+ // Check number of result registers
+  int expectedOutput = 0;
+  if (outputType.isF32() || outputType.isInteger(32))
+    expectedOutput = getShape().getN() / 2;
+  if (outputType.isF16())
+    expectedOutput = getShape().getN() / 4;
+  if (outputSize != expectedOutput) {
+    return emitOpError() << "expected " << expectedOutput
+                         << " results, however output struct has " << outputSize
+                         << " elements";
+ }
+  // Check that `satfinite` is only set for an s32 accumulator.
+  if (!outputType.isInteger(32) &&
+      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+          NVVM::MMAIntOverflow::satfinite) {
+    return emitOpError()
+           << "`satfinite` can only be used with an s32 accumulator, however "
+              "the current accumulator is "
+           << outputType;
+ }
+
+ return success();
+}
+
+std::string NVVM::WgmmaMmaSyncOp::getPtx() {
+
+ int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
+ bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
+ getTypeA() == mlir::NVVM::MMATypes::bf16;
+
+ Value outValue = getResults() ? getResults() : getInouts();
+ auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+ Type outputType = stype.getBody().front();
+  std::string outputTypeName;
+  if (outputType.isF16())
+    outputTypeName = "f16";
+  else if (outputType.isF32())
+    outputTypeName = "f32";
+  else if (outputType.isInteger(32))
+    outputTypeName = "s32";
+  int expectedOutputRegisters = 0;
+  if (outputType.isF32() || outputType.isInteger(32))
+    expectedOutputRegisters = getShape().getN() / 2;
+  else if (outputType.isF16())
+    expectedOutputRegisters = getShape().getN() / 4;
+
+ std::string ptx;
+ llvm::raw_string_ostream ss(ptx);
+
+ ss << "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, $"
+ << (expectedOutputRegisters + 2)
+ << ", 0;\n"
+ "wgmma.mma_async.sync.aligned.m"
+ << m << "n" << n << "k" << k << "." << outputTypeName << "."
+ << stringifyMMATypes(getTypeA()) << "." << stringifyMMATypes(getTypeB());
+ if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+ NVVM::MMAIntOverflow::satfinite)
+ ss << ".satfinite";
+ ss << " {";
+ int regCnt = 0;
+ for (; regCnt < expectedOutputRegisters; ++regCnt) {
+ ss << "$" << regCnt;
+ if (regCnt != expectedOutputRegisters - 1)
+ ss << ", ";
+ }
+
+ ss << "},";
+ ss << " $" << (expectedOutputRegisters) << ","
+ << " $" << (expectedOutputRegisters + 1) << ","
+ << " p";
+ if (!outputType.isInteger(32)) {
+ ss << ", $" << (expectedOutputRegisters + 3) << ", $"
+ << (expectedOutputRegisters + 4);
+ }
+ // Don't add transpose parameters unless needed.
+ if (isF16) {
+ ss << ", $" << (expectedOutputRegisters + 5) << ", $"
+ << (expectedOutputRegisters + 6);
+ }
+ ss << ";\n"
+ << "}\n";
+ ss.flush();
+ return ptx;
+}
+
+void NVVM::WgmmaMmaSyncOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ Value outValue = getResults() ? getResults() : getInouts();
+ auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+ Type outputType = stype.getBody().front();
+ bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
+ getTypeA() == mlir::NVVM::MMATypes::bf16;
+ if (getResults())
+ asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
+ if (getInouts())
+ asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
+ asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
+ mlir::NVVM::PTXRegisterMod::Read});
+ if (!outputType.isInteger(32)) {
+ asmValues.push_back(
+ {makeConstantI32(rewriter,
+ getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+ mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back(
+ {makeConstantI32(rewriter,
+ getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+ mlir::NVVM::PTXRegisterMod::Read});
+ }
+ if (isF16) {
+ asmValues.push_back(
+ {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
+ mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back(
+ {makeConstantI32(rewriter, static_cast<int>(getLayoutB())),
+ mlir::NVVM::PTXRegisterMod::Read});
+ }
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
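For reference, `getPtx()` emits strings like the following (the m64n8k32 s32 += s8 * s8 case verified in the tests below, shown with the inline-asm `\0A` escapes expanded); `$0`-`$3` are the accumulator registers, `$4`/`$5` the matrix descriptors, and the predicate `p` is derived from the scale-out operand `$6`:
```
{
.reg .pred p;
setp.ne.b32 p, $6, 0;
wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;
}
```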
diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
new file mode 100644
index 00000000000000..7f8fa5267c343f
--- /dev/null
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file -verify-diagnostics %s
+
+!mat64f32 = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32)>
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
+ %result = llvm.mlir.undef : !mat64f32
+  // expected-error @+1 {{'nvvm.wgmma.mma_async' op expected 64 results, however output struct has 7 elements}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 128, k = 16>,
+ D [%result, <zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !mat64f32 -> !mat64f32
+ return %res : !mat64f32
+}
+
+// -----
+
+func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error @+1 {{`satfinite` can only be used with an s32 accumulator, however the current accumulator is 'f32'}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 16, k = 16>,
+ D [%result, <zero>, <satfinite>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ return
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // expected-error @+1 {{shape 'm' must be 64}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 32, n = 16, k = 16>,
+ D [%result, <zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ return
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)>
+  // expected-error @+1 {{op all elements in the struct must have the same type, but found 'i32'}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 16, k = 16>,
+ D [%result, <zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)>
+ -> !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)>
+ return
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 16, k = 3>,
+ D [%result, <zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ return
+}
+
+// -----
+
+func.func @wgmma_transpose(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error @+1 {{op given layouts layout_a = col and layout_b = col for input types tf32 and tf32 require transpose. However, this is only supported for: f16 and bf16}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 16, k = 8>,
+ D [%result, <zero>],
+ A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ return
+}
+
+// -----
+
+func.func @wgmma_transpose(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16)>
+  // expected-error @+1 {{'nvvm.wgmma.mma_async' op 'f16' += tf32 * tf32 is not a supported combination}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 16, k = 8>,
+ D [%result, <zero>],
+ A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
+ :!llvm.struct<(f16, f16, f16, f16)>
+ -> !llvm.struct<(f16, f16, f16, f16)>
+ return
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
+ %result = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
+ // expected-error @+1 {{input struct and result struct must be the same type}}
+ %res = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 16>,
+ D [%result, <zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !llvm.struct<(i32, i32, i32, i32)>
+ -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 9ba913b9d3ea2a..342fc2fc0430b1 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-nvvm-to-llvm --convert-arith-to-llvm --split-input-file %s | FileCheck %s
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
@@ -103,3 +103,168 @@ func.func @wgmma_execute() {
// CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned %0;", "n" %{{.*}} : (i32)
return
}
+
+// -----
+
+!mat64f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_f16_f16(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
+ // CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+ // CHECK: %[[A1:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[A5:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V13:.*]] = llvm.extractvalue %[[RES]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[RES1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19, $20, $21, $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11]], %{{.*}}, %[[V13]], %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[DESCa:.+]] = llvm.add %[[ARG0]], %[[C2]] : i64
+ // CHECK: %[[DESCb:.+]] = llvm.add %[[ARG1]], %[[C2]] : i64
+ // CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES1]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V4_2:.*]] = llvm.extractvalue %[[RES1]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V11_2:.*]] = llvm.extractvalue %[[RES1]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V13_2:.*]] = llvm.extractvalue %[[RES1]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19, $20, $21, $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ %result = llvm.mlir.undef : !mat64f32
+ %result1 = nvvm.wgmma.mma_async
+ %descA, %descB,
+ #nvvm.shape<m = 64, n = 32, k = 16>,
+ D [%result, #nvvm.wgmma_scale_out<zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ :!mat64f32 -> !mat64f32
+ %c2 = arith.constant 2 : i64
+ %descAnext = arith.addi %descA, %c2 : i64
+ %descBnext = arith.addi %descB, %c2 : i64
+ %result2 = nvvm.wgmma.mma_async
+ %descAnext, %descBnext,
+ #nvvm.shape<m = 64, n = 32, k = 16>,
+ D [%result1, #nvvm.wgmma_scale_out<zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !mat64f32 -> !mat64f32
+ return %result2 : !mat64f32
+}
+
+// -----
+
+!mat16i32 = !llvm.struct<(i32, i32, i32, i32)>
+
+// CHECK-LABEL: @wgmma_s32_s8_s8_satfinite(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
+ %result = llvm.mlir.undef : !mat16i32
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
+ %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [%result, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+ : !mat16i32 -> !mat16i32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [%result1, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+ : !mat16i32 -> !mat16i32
+ %result3 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [%result2, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+ : !mat16i32 -> !mat16i32
+ return %result3 : !mat16i32
+}
+
+// CHECK-LABEL: @wgmma_s32_u8_u8(
+ // CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
+ %result = llvm.mlir.undef : !mat16i32
+ %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [%result, #nvvm.wgmma_scale_out<one>],
+ A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+ : !mat16i32 -> !mat16i32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [%result1, #nvvm.wgmma_scale_out<one>],
+ A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+ : !mat16i32 -> !mat16i32
+ %result3 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [%result2, #nvvm.wgmma_scale_out<one>],
+ A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+ : !mat16i32 -> !mat16i32
+ return %result3 : !mat16i32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_tf32_tf32
+func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
+ // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35, $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ %result = llvm.mlir.undef : !mat32f32
+ %result1 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 64, k = 8>,
+ D [%result, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ : !mat32f32 -> !mat32f32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB,
+ #nvvm.shape<m = 64, n = 64, k = 8>,
+ D [%result1, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ : !mat32f32 -> !mat32f32
+ return %result2 : !mat32f32
+}