[Mlir-commits] [mlir] 3be7c28 - [mlir][NVVM] Add support for nvvm mma.sync ops
Thomas Raoux
llvmlistbot at llvm.org
Fri Mar 25 10:32:07 PDT 2022
Author: Christopher Bate
Date: 2022-03-25T17:28:05Z
New Revision: 3be7c2891798025e524fce1b7cdaa27aee6c9816
URL: https://github.com/llvm/llvm-project/commit/3be7c2891798025e524fce1b7cdaa27aee6c9816
DIFF: https://github.com/llvm/llvm-project/commit/3be7c2891798025e524fce1b7cdaa27aee6c9816.diff
LOG: [mlir][NVVM] Add support for nvvm mma.sync ops
This patch adds MLIR NVVM support for the various NVPTX `mma.sync`
operations. The operation supports many combinations of data type,
shape, and other attributes, so a custom assembly format is added and
attributes are inferred where possible.
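As an illustration, the custom format reads roughly as follows (a minimal
sketch drawn from the tests added in this patch; the SSA value names are
placeholders):

```mlir
// m16n8k16 f16 variant: the multiplicand PTX types are inferred from the
// vector<2xf16> operand types, so no explicit type attributes are needed.
%d = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
       {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
        shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
       : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
       -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
```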
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D122410
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index f6513a242d3fc..c44d792abf01d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -35,6 +35,8 @@ set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(NVVMOpsStructs.h.inc -gen-struct-attr-decls)
+mlir_tablegen(NVVMOpsStructs.cpp.inc -gen-struct-attr-defs)
mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
add_public_tablegen_target(MLIRNVVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index de942f6fb4d31..7be5f49c0326a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -21,6 +21,7 @@
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
+#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.h.inc"
namespace mlir {
namespace NVVM {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d65b525eacf6f..f9d32f480888a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -195,18 +195,6 @@ def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
let assemblyFormat = "$n attr-dict";
}
-def NVVM_MmaOp :
- NVVM_Op<"mma.sync">,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins Variadic<LLVM_Type>:$args)> {
- string llvmBuilder = [{
- $res = createIntrinsicCall(
- builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
- }];
- let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
- let hasVerifier = 1;
-}
-
/// Helpers to instantiate different version of wmma intrinsics.
/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
/// combinations of the intrinsics.
@@ -296,6 +284,7 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
// Creates list of valid combinations of fragments. This is a subset of what
// llvm supports and can be extended as needed.
class NVVM_MMA_OPS {
+ // "wmma" operations
list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
[GEOM<16, 16, 8>],
["tf32"], [], ["f32"], []>.ret;
@@ -324,6 +313,32 @@ class NVVM_MMA_OPS {
// Separate A/B/C fragments (loads) from D (stores).
list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
+
+ // "mma_sync" operations
+ list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
+ [GEOM<16,8,4>, GEOM<16,8,8>],
+ ["tf32"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
+ [GEOM<16,8,16>, GEOM<16,8,8>],
+ ["bf16"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
+ [GEOM<8,8,4>],
+ ["f64"], [], ["f64"], []>.ret;
+ list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+ [GEOM<8,8,4>, GEOM<16,8,8>, GEOM<16,8,16>],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+ list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
+ [GEOM<8,8,16>, GEOM<16,8,16>, GEOM<16,8,32>],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
+ [GEOM<8,8,32>, GEOM<16,8,32>, GEOM<16,8,64>],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
+ [GEOM<8,8,128>, GEOM<16,8,128>, GEOM<16,8,256>],
+ ["b1"], [], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
+ tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
+ fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -405,6 +420,150 @@ class MMA_MMA_INTR<string opName> {
string id = !foldl("", f, acc, el, acc # "\n" # el);
}
+/// Enum attribute for binary (b1) MMA operation type
+def MMAB1OpNone : I32EnumAttrCase<"none", 0>;
+def MMAB1OpXorPopc : I32EnumAttrCase<"xor_popc", 1>;
+def MMAB1OpAndPopc : I32EnumAttrCase<"and_popc", 2>;
+def MMAB1Op : I32EnumAttr<"MMAB1Op", "MMA binary operations",
+ [MMAB1OpNone, MMAB1OpXorPopc, MMAB1OpAndPopc]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def MMAB1OpAttr : EnumAttr<NVVM_Dialect, MMAB1Op, "mma_b1op"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+/// Enum attribute type for the overflow behavior of MMA integer operations
+def MMAIntOverflowWrap : I32EnumAttrCase<"wrapped", 0>;
+def MMAIntOverflowSat : I32EnumAttrCase<"satfinite", 1>;
+def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
+ [MMAIntOverflowSat, MMAIntOverflowWrap]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+/// Attribute to hold the MMA shape
+def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [
+ StructFieldAttr<"m", I32Attr>,
+ StructFieldAttr<"n", I32Attr>,
+ StructFieldAttr<"k", I32Attr>
+ ]> {
+ let summary = "Attribute for MMA operation shape.";
+}
+
+// Returns true if this combination of layout/satf for MMA ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
+ // MMA ops check both layouts.
+ string layout = layout_a # ":" # layout_b;
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ // gcd is a shortcut used to identify instructions that depend on
+ // geom+frag_c+frag_d.
+ string gcd = geom # ":" # c_type # d_type;
+ bit ret = !cond(
+
+ // Limit satf to valid types
+ !and(!eq(satf, 1),
+ !ne(a_type, "s8"),
+ !ne(a_type, "u8"),
+ !ne(a_type, "s4"),
+ !ne(a_type, "u4")): false,
+
+ // m8n8k4 has no C=f32 D=f16 variant.
+ !eq(gcd, "m8n8k4:f32f16"): false,
+
+ // only m8n8k4 for f16 does not require row:col layout
+ !and(!ne(layout, "row:col"),
+ !or(!ne(geom, "m8n8k4"),
+ !ne(a_type, "f16"))) : false,
+
+ // m16n8k8 requires A and B to be the same type and C and D to be the same
+ // type.
+ !and(!eq(geom, "m16n8k8"),
+ !or(!ne(a_type, b_type),
+ !ne(c_type, d_type))): false,
+
+ // m16n8k8 requires C and D to be the same type.
+ !and(!eq(geom, "m16n8k8"),
+ !ne(c_type, d_type)): false,
+
+ // All other are OK.
+ true: true
+ );
+}
+
+// Returns a list of operation suffixes corresponding to possible b1
+// multiply-and-accumulate operations for all fragments which have a
+// b1 type. For all other fragments, the returned list contains only
+// the empty string.
+class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
+ list<string> ret = !cond(
+ !eq(frags[0].ptx_elt_type, "b1") : ["xor_popc", "and_popc"],
+ true: [""]
+ );
+}
+
+/// Generate enum value of the mma.sync intrinsic.
+class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string id = "llvm::Intrinsic::nvvm_mma"
+ # !if(!ne(b1op, ""), "_" # b1op, "")
+ # "_" # A.geom
+ # "_" # ALayout
+ # "_" # BLayout
+ # !if(Satfinite, "_satfinite", "")
+ # signature;
+}
+
+/// Helper to create the mapping between the configuration and the mma.sync
+/// intrinsic enum value.
+class MMA_SYNC_INTR {
+ list<list<list<list<list<string>>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_sync_ops,
+ !foreach(layoutA, ["row", "col"],
+ !foreach(layoutB, ["row", "col"],
+ !foreach (sat, [0, 1],
+ !foreach (b1op, NVVM_MMA_B1OPS<op>.ret,
+ !if(NVVM_MMA_SUPPORTED<[op[0], op[1], op[2], op[3]],
+ layoutA, layoutB, sat>.ret,
+ "if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && "
+ " m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
+ " && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
+ # op[1].ptx_elt_type # "\" == eltypeB && "
+ # " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
+ # " \"" # op[3].ptx_elt_type # "\" == eltypeD "
+ # " && (sat.hasValue() ? " # sat # " == static_cast<int>(*sat) : true)"
+ # !if(!ne(b1op, ""), " && (b1Op.hasValue() ? MMAB1Op::" # b1op # " == b1Op.getValue() : true)", "") # ")\n"
+ # " return " #
+ MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // b1op
+ ) // sat
+ ) // layoutB
+ ) // layoutA
+ ); // all_mma_sync_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]],
+ !foldl([[[[""]]]], cond0, acc, el,
+ !listconcat(acc, el)),
+ acc1, el1, !listconcat(acc1, el1));
+ list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
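+
+// For reference, each entry in the generated MMA_SYNC_INTR<>.id string is a
+// guard of roughly the following shape (a sketch for the m16n8k16 row/col
+// f16.f16 variant; the exact text is produced by the concatenation above):
+//
+//   if (layoutA == "row" && layoutB == "col" && m == 16 && n == 8 && k == 16
+//       && "f16" == eltypeA && "f16" == eltypeB && "f16" == eltypeC &&
+//       "f16" == eltypeD && (sat.hasValue() ? 0 == static_cast<int>(*sat) : true))
+//     return llvm::Intrinsic::nvvm_mma_m16n8k16_row_col_f16_f16;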
+
def MMALayoutRow : I32EnumAttrCase<"row", 0>;
def MMALayoutCol : I32EnumAttrCase<"col", 1>;
@@ -418,13 +577,24 @@ def MMALayoutAttr : EnumAttr<NVVM_Dialect, MMALayout, "mma_layout"> {
let assemblyFormat = "`<` $value `>`";
}
+/// Enum attribute of the different PTX element types used for MMA operands.
def MMATypeF16 : I32EnumAttrCase<"f16", 0>;
def MMATypeF32 : I32EnumAttrCase<"f32", 1>;
def MMATypeTF32 : I32EnumAttrCase<"tf32", 2>;
+def MMATypeU8 : I32EnumAttrCase<"u8", 3>;
+def MMATypeS8 : I32EnumAttrCase<"s8", 4>;
+def MMATypeS32 : I32EnumAttrCase<"s32", 5>;
+def MMATypeB1 : I32EnumAttrCase<"b1", 6>;
+def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
+def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
+def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
+def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
-/// Enum attribute of the different matrix types.
def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
- [MMATypeF16, MMATypeF32, MMATypeTF32]> {
+ [MMATypeF16, MMATypeF32, MMATypeTF32,
+ MMATypeBF16, MMATypeS8, MMATypeU8,
+ MMATypeS32, MMATypeS4, MMATypeU4,
+ MMATypeB1, MMATypeF64]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -678,4 +848,141 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
let hasVerifier = 1;
}
+def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
+
+ let summary = "cooperative matrix-multiply and accumulate";
+
+ let description = [{
+ The `nvvm.mma.sync` operation collectively performs the operation
+ `D = matmul(A, B) + C` using all threads in a warp.
+
+ All the threads in the warp must execute the same `mma.sync` operation.
+
+  For each possible multiplicand PTX data type, there are one or more possible
+  instruction shapes given as "mMnNkK". The table below describes the possibilities
+  as well as the types required for the operands. Note that the data type for
+  C (the accumulator) and D (the result) can vary independently when there are
+  multiple possibilities in the "C/D Type" column.
+
+ When an optional attribute cannot be immediately inferred from the types of
+ the operands and the result during parsing or validation, an error will be
+ raised.
+
+  `b1Op` is only relevant when the multiplicand PTX type is the binary (b1)
+  type. It specifies how the multiply-and-accumulate is performed and is
+  either `xor_popc` or `and_popc`. The default is `xor_popc`.
+
+  `intOverflowBehavior` is only relevant when the multiplicand PTX type is one
+  of `u8`, `s8`, `u4`, or `s4`; it describes how overflow is handled in the
+  accumulator. When the attribute is `satfinite`, the accumulator values are
+  clamped in the int32 range on overflow. This is the default behavior.
+  Alternatively, the `wrapped` behavior can be specified, in which case
+  overflow wraps from one end of the range to the other (see the second
+  example below).
+
+  `layoutA` and `layoutB` are required and should generally be set to
+  `#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
+  combinations are possible for certain data types, as shown in the table below.
+
+ ```
+  | A/B Type | Shape      | ALayout | BLayout | A Type   | B Type   | C/D Type           |
+  |----------|------------|---------|---------|----------|----------|--------------------|
+  | f64      | .m8n8k4    | row     | col     | 1x f64   | 1x f64   | 2x f64             |
+  | f16      | .m8n8k4    | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8x f32 |
+  |          | .m16n8k8   | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4x f32 |
+  |          | .m16n8k16  | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4x f32 |
+  | bf16     | .m16n8k8   | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4x f32 |
+  |          | .m16n8k16  | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4x f32 |
+  | tf32     | .m16n8k4   | row     | col     | 2x i32   | 1x i32   | 4x f32             |
+  |          | .m16n8k8   | row     | col     | 4x i32   | 2x i32   | 2x f16x2 or 4x f32 |
+  | u8/s8    | .m8n8k16   | row     | col     | 1x i32   | 1x i32   | 2x i32             |
+  |          | .m16n8k16  | row     | col     | 2x i32   | 1x i32   | 4x i32             |
+  |          | .m16n8k32  | row     | col     | 4x i32   | 2x i32   | 4x i32             |
+  | u4/s4    | .m8n8k32   | row     | col     | 1x i32   | 1x i32   | 2x i32             |
+  |          | .m16n8k32  | row     | col     | 2x i32   | 1x i32   | 4x i32             |
+  |          | .m16n8k64  | row     | col     | 4x i32   | 2x i32   | 4x i32             |
+  | b1       | .m8n8k128  | row     | col     | 1x i32   | 1x i32   | 2x i32             |
+  |          | .m16n8k128 | row     | col     | 2x i32   | 1x i32   | 4x i32             |
+ ```
+
+
+ Example:
+ ```mlir
+
+ %128 = nvvm.mma.sync A[%120, %121, %122, %123]
+ B[%124, %125]
+ C[%126, %127]
+ {layoutA = #nvvm.mma_layout<row>,
+ layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
+ -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ ```
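+
+  For operands whose PTX element type cannot be inferred from the MLIR operand
+  types (e.g. the integer variants), the type and overflow attributes must be
+  given explicitly. A sketch of an s8 m8n8k16 variant, mirroring the tests
+  accompanying this patch:
+
+  ```mlir
+  %1 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
+         {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+          multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
+          intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+          shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}}
+         : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+  ```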
+ }];
+
+ let results = (outs LLVM_AnyStruct:$res);
+ let arguments = (ins NVVM_MMAShapeAttr:$shape,
+ OptionalAttr<MMAB1OpAttr>:$b1Op,
+ OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
+ MMALayoutAttr:$layoutA,
+ MMALayoutAttr:$layoutB,
+ OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+ OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+ Variadic<LLVM_Type>:$operandA,
+ Variadic<LLVM_Type>:$operandB,
+ Variadic<LLVM_Type>:$operandC);
+
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ llvm::Optional<MMAB1Op> b1Op,
+ llvm::Optional<MMAIntOverflow> sat,
+ mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
+ llvm::StringRef layoutA = stringifyEnum(layoutAEnum);
+ llvm::StringRef layoutB = stringifyEnum(layoutBEnum);
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+ llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
+ }],
+ MMA_SYNC_INTR<>.id, [{
+ return 0;
+ }
+
+ static Optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
+ bool isAccumulator);
+
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+ }]);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
+ "Optional<MMAIntOverflow>":$intOverflow,
+ "Optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+ "Optional<std::array<MMALayout, 2>>":$multiplicandLayouts)>
+ ];
+
+ string llvmBuilder = [{
+ auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+ auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
+ $shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(),
+ $b1Op, $intOverflowBehavior,
+ $layoutA, $layoutB,
+ $multiplicandAPtxType.getValue(),
+ $multiplicandBPtxType.getValue(),
+ op.accumPtxType(),
+ op.resultPtxType());
+
+ $res = createIntrinsicCall(
+ builder, intId, operands);
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
#endif // NVVMIR_OPS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9b909a2b1a5bb..c2c1eb49e1726 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -34,6 +34,7 @@ using namespace NVVM;
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
+#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.cpp.inc"
//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
@@ -69,47 +70,455 @@ LogicalResult CpAsyncOp::verify() {
return success();
}
+// Given the element type of an operand and whether or not it is an accumulator,
+// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
+// operand's element type.
+Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
+ bool isAccumulator) {
+ auto half2Type =
+ LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+ if (operandElType.isF64())
+ return NVVM::MMATypes::f64;
+ if (operandElType.isF16() || operandElType == half2Type)
+ return NVVM::MMATypes::f16;
+ if (operandElType.isF32())
+ return NVVM::MMATypes::f32;
+ if (operandElType.isa<IntegerType>()) {
+ if (isAccumulator)
+ return NVVM::MMATypes::s32;
+ return llvm::None;
+ }
+
+ if (auto structType = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
+ if (structType.getBody().empty())
+ return llvm::None;
+ return inferOperandMMAType(structType.getBody()[0], isAccumulator);
+ }
+
+ return llvm::None;
+}
+
+static bool isInt4PtxType(MMATypes type) {
+ return (type == MMATypes::u4 || type == MMATypes::s4);
+}
+
+static bool isInt8PtxType(MMATypes type) {
+ return (type == MMATypes::u8 || type == MMATypes::s8);
+}
+
+static bool isIntegerPtxType(MMATypes type) {
+ return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
+ type == MMATypes::s32;
+}
+
+MMATypes MmaOp::accumPtxType() {
+ Optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
+ getODSOperands(2).getTypes().front(), /*isAccum=*/true);
+ assert(val.hasValue() && "accumulator PTX type should always be inferrable");
+ return val.getValue();
+}
+
+MMATypes MmaOp::resultPtxType() {
+ Optional<mlir::NVVM::MMATypes> val =
+ inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
+ assert(val.hasValue() && "result PTX type should always be inferrable");
+ return val.getValue();
+}
+
+void MmaOp::print(OpAsmPrinter &p) {
+ SmallVector<Type, 4> regTypes;
+ struct OperandFragment {
+ StringRef operandName;
+ StringRef ptxTypeAttr;
+ SmallVector<Value, 4> regs;
+ explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ : operandName(name), ptxTypeAttr(ptxTypeName) {}
+ };
+
+ std::array<OperandFragment, 3> frags{
+ OperandFragment("A", multiplicandAPtxTypeAttrName()),
+ OperandFragment("B", multiplicandBPtxTypeAttrName()),
+ OperandFragment("C", "")};
+ SmallVector<StringRef, 4> ignoreAttrNames{
+ mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
+
+ for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+ auto &frag = frags[fragIdx];
+ auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+ for (auto operandIdx = varOperandSpec.first;
+ operandIdx < varOperandSpec.first + varOperandSpec.second;
+ operandIdx++) {
+ frag.regs.push_back(this->getOperand(operandIdx));
+ if (operandIdx == 0) {
+ regTypes.push_back(this->getOperand(operandIdx).getType());
+ }
+ }
+ Optional<MMATypes> inferredType =
+ inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
+ if (inferredType)
+ ignoreAttrNames.push_back(frag.ptxTypeAttr);
+ }
+
+ auto printMmaOperand = [&](const OperandFragment &frag) -> void {
+ p << " " << frag.operandName;
+ p << "[";
+ p.printOperands(frag.regs);
+ p << "] ";
+ };
+
+ for (const auto &frag : frags) {
+ printMmaOperand(frag);
+ }
+
+ p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
+
+ // Print the types of the operands and result.
+ p << " : "
+ << "(";
+ llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
+ frags[1].regs[0].getType(),
+ frags[2].regs[0].getType()},
+ p);
+ p << ")";
+ p.printArrowTypeList(TypeRange{this->res().getType()});
+}
+
+void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ ArrayRef<int64_t> shape, Optional<MMAB1Op> b1Op,
+ Optional<MMAIntOverflow> intOverflow,
+ Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+ Optional<std::array<MMALayout, 2>> multiplicandLayouts) {
+
+ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+ MLIRContext *ctx = builder.getContext();
+ Type i32 = builder.getIntegerType(32);
+ result.addAttribute(
+ "shape", MMAShapeAttr::get(builder.getIntegerAttr(i32, shape[0]),
+ builder.getIntegerAttr(i32, shape[1]),
+ builder.getIntegerAttr(i32, shape[2]), ctx));
+
+ result.addOperands(operandA);
+ result.addOperands(operandB);
+ result.addOperands(operandC);
+
+ if (multiplicandPtxTypes.hasValue()) {
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ } else {
+ if (auto res = inferOperandMMAType(operandA[0].getType(), false))
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+ if (auto res = inferOperandMMAType(operandB[0].getType(), false))
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+ }
+
+ if (multiplicandLayouts.hasValue()) {
+ result.addAttribute("layoutA",
+ MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
+ result.addAttribute("layoutB",
+ MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
+ } else {
+ result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
+ result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
+ }
+
+ if (intOverflow.hasValue())
+ result.addAttribute("intOverflowBehavior",
+ MMAIntOverflowAttr::get(ctx, *intOverflow));
+ if (b1Op.hasValue())
+ result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
+
+ result.addTypes(resultType);
+ result.addAttribute(
+ MmaOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size())}));
+}
+
+// <operation> :=
+// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
+// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
+// `->` type($res)
+ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
+ struct OperandFragment {
+ Optional<MMATypes> elemtype;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+ SmallVector<Type> regTypes;
+ };
+
+ Builder &builder = parser.getBuilder();
+ std::array<OperandFragment, 4> frags;
+
+ NamedAttrList namedAttributes;
+
+ // A helper to parse the operand segments.
+ auto parseMmaOperand = [&](StringRef operandName,
+ OperandFragment &frag) -> LogicalResult {
+ if (parser.parseKeyword(operandName).failed())
+ return failure();
+ if (parser
+ .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
+ return failure();
+ return success();
+ };
+
+ // Parse the operand segments.
+ if (parseMmaOperand("A", frags[0]).failed())
+ return failure();
+ if (parseMmaOperand("B", frags[1]).failed())
+ return failure();
+ if (parseMmaOperand("C", frags[2]).failed())
+ return failure();
+
+ if (parser.parseOptionalAttrDict(namedAttributes).failed())
+ return failure();
+
+ // Parse the type specification and resolve operands.
+ SmallVector<Type, 3> operandTypes;
+ if (failed(parser.parseColon()))
+ return failure();
+ if (failed(parser.parseLParen()))
+ return failure();
+ if (failed(parser.parseTypeList(operandTypes)))
+ return failure();
+  if (failed(parser.parseRParen()))
+    return failure();
+  if (operandTypes.size() != 3)
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expected one type for each operand segment but got " +
+            Twine(operandTypes.size()) + " types");
+ for (auto iter : llvm::enumerate(operandTypes)) {
+ auto &frag = frags[iter.index()];
+ frag.regTypes.resize(frag.regs.size(), iter.value());
+ if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+ parser.getNameLoc(), result.operands)))
+ return failure();
+ frag.elemtype =
+ inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
+ }
+
+ Type resultType;
+ parser.parseArrow();
+ parser.parseType(resultType);
+ frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
+
+ std::array<StringRef, 2> names{"multiplicandAPtxType",
+ "multiplicandBPtxType"};
+ for (unsigned idx = 0; idx < names.size(); idx++) {
+ const auto &frag = frags[idx];
+ Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+ if (!frag.elemtype.hasValue() && !attr.hasValue()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "attribute " + names[idx] +
+ " is not provided explicitly and cannot be inferred");
+ }
+ if (!attr.hasValue())
+ result.addAttribute(
+ names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+ }
+
+ result.addTypes(resultType);
+ if (!namedAttributes.empty())
+ result.addAttributes(namedAttributes);
+ result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ }));
+ return success();
+}
+
LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
+ auto i32Ty = IntegerType::get(context, 32);
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
- auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
- context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
-
- auto operandTypes = getOperandTypes();
- if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
- operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
- f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
- f32Ty}) {
- return emitOpError("expected operands to be 4 <halfx2>s followed by either "
- "4 <halfx2>s or 8 floats");
+
+ auto s32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+ auto f32x8StructTy =
+ LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+ auto f16x2x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+ auto f32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+ auto s32x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+ std::array<int64_t, 3> mmaShape{shapeAttr().m().getInt(),
+ shapeAttr().n().getInt(),
+ shapeAttr().k().getInt()};
+
+ // These variables define the set of allowed data types for matrices A, B, C,
+ // and result.
+ using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+ using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+ AllowedShapes allowedShapes;
+ AllowedTypes expectedA;
+ AllowedTypes expectedB;
+ AllowedTypes expectedC;
+ SmallVector<Type> expectedResult;
+
+ // When M = 16, we just need to calculate the number of 8xk tiles, where
+ // k is a factor that depends on the data type.
+ if (mmaShape[0] == 16) {
+ int64_t kFactor;
+ Type multiplicandFragType;
+ switch (multiplicandAPtxType().getValue()) {
+ case MMATypes::tf32:
+ kFactor = 4;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {i32Ty, i32Ty, i32Ty, i32Ty}));
+ break;
+ case MMATypes::f16:
+ case MMATypes::bf16:
+ kFactor = 8;
+ multiplicandFragType = f16x2Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ break;
+ case MMATypes::s4:
+ case MMATypes::u4:
+ kFactor = 32;
+ break;
+ case MMATypes::b1:
+ kFactor = 128;
+ break;
+ case MMATypes::s8:
+ case MMATypes::u8:
+ kFactor = 16;
+ break;
+ default:
+ return emitError("invalid shape or multiplicand type: " +
+ stringifyEnum(multiplicandAPtxType().getValue()));
+ }
+
+ if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
+ expectedResult.push_back(s32x4StructTy);
+ expectedC.emplace_back(4, i32Ty);
+ multiplicandFragType = i32Ty;
+ } else {
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ }
+
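+    // For example, m16n8k16 with f16 multiplicands has kFactor = 8, giving
+    // unitA = (16 / 8) * (16 / 8) = 4 and unitB = (8 / 8) * (16 / 8) = 2,
+    // i.e. 4 A registers and 2 B registers, matching the shape table in
+    // NVVMOps.td.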
+ int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
+ int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+ expectedA.emplace_back(unitA, multiplicandFragType);
+ expectedB.emplace_back(unitB, multiplicandFragType);
+ allowedShapes.push_back({16, 8, kFactor});
+ allowedShapes.push_back({16, 8, kFactor * 2});
}
- if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
- return emitOpError("expected result type to be a struct of either 4 "
- "<halfx2>s or 8 floats");
+
+ // In the M=8 case, there is only 1 possible case per data type.
+ if (mmaShape[0] == 8) {
+ if (multiplicandAPtxType().getValue() == MMATypes::f16) {
+ expectedA.emplace_back(2, f16x2Ty);
+ expectedB.emplace_back(2, f16x2Ty);
+ expectedResult.push_back(f16x2x4StructTy);
+ expectedResult.push_back(f32x8StructTy);
+ expectedC.emplace_back(4, f16x2Ty);
+ expectedC.emplace_back(8, f32Ty);
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (multiplicandAPtxType().getValue() == MMATypes::f64) {
+ Type f64Ty = Float64Type::get(context);
+ expectedA.emplace_back(1, f64Ty);
+ expectedB.emplace_back(1, f64Ty);
+ expectedC.emplace_back(2, f64Ty);
+ // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
+ expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(2, f64Ty)));
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
+ expectedA.push_back({i32Ty});
+ expectedB.push_back({i32Ty});
+ expectedC.push_back({i32Ty, i32Ty});
+ expectedResult.push_back(s32x2StructTy);
+ if (isInt4PtxType(multiplicandAPtxType().getValue()))
+ allowedShapes.push_back({8, 8, 32});
+ if (isInt8PtxType(multiplicandAPtxType().getValue()))
+ allowedShapes.push_back({8, 8, 16});
+ if (multiplicandAPtxType().getValue() == MMATypes::b1)
+ allowedShapes.push_back({8, 8, 128});
+ }
}
- auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
- auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
+ std::string errorMessage;
+ llvm::raw_string_ostream errorStream(errorMessage);
- if (!(alayout && blayout) ||
- !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
- !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
- return emitOpError("alayout and blayout attributes must be set to either "
- "\"row\" or \"col\"");
+ // Check that we matched an existing shape/dtype combination.
+ if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+ !llvm::any_of(allowedShapes,
+ [&](const auto &allowed) { return allowed == mmaShape; })) {
+ errorStream << "unimplemented variant for MMA shape <";
+ llvm::interleaveComma(mmaShape, errorStream);
+ errorStream << ">";
+ return emitOpError(errorMessage);
}
- if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
- f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
- f32Ty} &&
- getType() == f32x8StructTy && alayout.getValue() == "row" &&
- blayout.getValue() == "col") {
- return success();
+ // Verify the operand types for segments of A, B, and C operands.
+ std::array<StringRef, 3> operandNames{"A", "B", "C"};
+ for (const auto &iter : llvm::enumerate(
+ SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+ auto spec = this->getODSOperandIndexAndLength(iter.index());
+ SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+ operand_type_begin() + spec.first +
+ spec.second);
+ bool match =
+ llvm::any_of(iter.value(), [&](const SmallVector<Type, 4> &typeSet) {
+ return typeSet == operandTySeg;
+ });
+
+ if (!match) {
+ errorStream << "Could not match types for the "
+ << operandNames[iter.index()]
+ << " operands; expected one of ";
+ for (const auto &x : iter.value()) {
+ errorStream << x.size() << "x" << x[0] << " ";
+ }
+ errorStream << "but got ";
+ llvm::interleaveComma(operandTySeg, errorStream);
+ return emitOpError(errorStream.str());
+ }
+ }
+
+ // Check the result type
+ if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+ return expectedResultType == getResult().getType();
+ })) {
+ errorStream
+ << "Could not match allowed types for the result; expected one of ";
+ llvm::interleaveComma(expectedResult, errorStream);
+ errorStream << " but got " << getResult().getType();
+ return emitOpError(errorStream.str());
+ }
+
+ // Ensure that binary MMA variants have a b1 MMA operation defined.
+ if (multiplicandAPtxType() == MMATypes::b1 && !b1Op().hasValue()) {
+ return emitOpError("op requires " + b1OpAttrName().strref() + " attribute");
}
- return emitOpError("unimplemented mma.sync variant");
+
+ // Ensure int4/int8 MMA variants specify the accum overflow behavior
+ // attribute.
+ if (isInt4PtxType(*multiplicandAPtxType()) ||
+ isInt8PtxType(*multiplicandAPtxType())) {
+ if (!intOverflowBehavior().hasValue())
+ return emitOpError("op requires " +
+ intOverflowBehaviorAttrName().strref() + " attribute");
+ }
+
+ return success();
}
LogicalResult ShflOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6c7e3ae5712d7..cc4ae43b7ae3e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -514,12 +514,13 @@ func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i3
// -----
-func @nvvm_invalid_mma_0(%a0 : f16, %a1 : vector<2xf16>,
+func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // expected-error at +1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (f16, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // expected-error at +1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+ {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
@@ -529,8 +530,9 @@ func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // expected-error at +1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
+ // expected-error at +1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+ {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
}
@@ -540,8 +542,9 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // expected-error at +1 {{alayout and blayout attributes must be set to either "row" or "col"}}
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // expected-error at +1 {{op requires attribute 'layoutA'}}
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+ {shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
@@ -549,55 +552,23 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
- // expected-error at +1 {{unimplemented mma.sync variant}}
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
- llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_4(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
- %b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // expected-error at +1 {{unimplemented mma.sync variant}}
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+ // expected-error at +1 {{unimplemented variant for MMA shape <8, 8, 16>}}
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// -----
-func @nvvm_invalid_mma_5(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
- %b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // expected-error at +1 {{unimplemented mma.sync variant}}
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
- llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_6(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
- %b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // expected-error at +1 {{invalid kind of type specified}}
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
- llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
- %b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // expected-error at +1 {{op requires one result}}
- %0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> (!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, i32)
- llvm.return %0#0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ // expected-error at +1 {{op requires b1Op attribute}}
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+ shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
}
// -----
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 2b191541c3b02..c2e1db76f251c 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -66,15 +66,164 @@ func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
llvm.return %0 : i32
}
-func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
+func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
- // CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
+ // CHECK: nvvm.mma.sync
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
+func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
+ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}]
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
+ %c0 : i32, %c1 : i32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+ %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
+ shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+ llvm.return %0 : !llvm.struct<(i32, i32)>
+}
+
+func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+ // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
+ shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
+ shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+ b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+ b1Op = #nvvm.mma_b1op<xor_popc>,
+ shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m8n8k128_b1_b1
+func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+ %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+ b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4
+func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
+ shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+// CHECK-LABEL: @nvvm_wmma_load_tf32
func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
// CHECK: nvvm.wmma.load {{.*}} {eltype = #nvvm.mma_type<tf32>, frag = #nvvm.mma_frag<a>, k = 8 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}
%0 = nvvm.wmma.load %arg0, %arg1
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index b62913b7c2737..ad73a295359df 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -88,17 +88,124 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
llvm.return %3 : i32
}
-llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+// CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
+llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
// CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
- %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
+llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16
+ %0 = nvvm.mma.sync A[ %a0, %a1, %a2, %a3 ] B[ %b0, %b1 ] C[ %c0, %c1 ]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// f32 return type, f16 accumulate type
+llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// f16 return type, f32 accumulate type
+llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// f32 return type, f32 accumulate type
+llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
+ shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
+ shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+ b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k32.row.col.satfinite.s4
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
+ shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
+ %b0 : f64,
+ %c0 : f64, %c1 : f64) -> !llvm.struct<(f64, f64)> {
+ // CHECK: call { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64
+ %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (f64, f64, f64) -> !llvm.struct<(f64, f64)>
+ llvm.return %0 : !llvm.struct<(f64, f64)>
+}
+
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
// in the LLVM NVPTX backend.
+// CHECK-LABEL: @gpu_wmma_load_op
llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
// CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
%0 = nvvm.wmma.load %arg0, %arg1