[Mlir-commits] [mlir] [mlir][nvvm] Introduce `nvvm.stmatrix` Op (PR #69467)
Guray Ozen
llvmlistbot at llvm.org
Thu Oct 19 01:26:18 PDT 2023
================
@@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+ Arguments<(ins LLVM_i8Ptr_shared:$ptr,
+ Variadic<I32>:$sources,
+ MMALayoutAttr:$layout)> {
+ let summary = "cooperative matrix store";
+ let description = [{
+ Collectively store one or more matrices across all threads in a warp to the
+ location indicated by the address operand $ptr in shared memory.
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
+ }];
+
+ let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ int d = getSources().size();
+ std::string ptx = "stmatrix.sync.aligned";
+ ptx += ".x" + std::to_string(d);
+ if (getLayout() == NVVM::MMALayout::col)
+ ptx += ".trans";
+ if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1}";
+ if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2}";
+ if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
----------------
grypp wrote:
The verifier actually catches that.
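
For readers without the rest of the thread, here is a minimal standalone sketch of the idea. It assumes the point under discussion is a source count other than 1, 2, or 4, which the three `if (d == ...)` branches in the hunk would otherwise pass through without adding operands. The names `Layout`, `isValidSourceCount`, and `buildStMatrixPtx` are hypothetical and not part of the PR; the real check lives in the op's MLIR verifier. The string layout follows the hunk above, except that the sketch ends every variant with `;`, which is also an assumption.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for NVVM::MMALayout (not the real MLIR enum).
enum class Layout { Row, Col };

// Assumed verifier rule: stmatrix stores 1, 2, or 4 matrices.
bool isValidSourceCount(int n) { return n == 1 || n == 2 || n == 4; }

// Builds the PTX string in the same order as the hunk above.
// Returns std::nullopt for a source count the assumed verifier would reject.
std::optional<std::string> buildStMatrixPtx(int numSources, Layout layout) {
  if (!isValidSourceCount(numSources))
    return std::nullopt;

  std::string ptx = "stmatrix.sync.aligned";
  ptx += ".x" + std::to_string(numSources);
  if (layout == Layout::Col)
    ptx += ".trans";
  ptx += ".m8n8.shared.b16 [%0], {";

  // Operand placeholders %1..%N, comma separated.
  for (int i = 1; i <= numSources; ++i) {
    if (i > 1)
      ptx += ", ";
    ptx += "%" + std::to_string(i);
  }
  ptx += "};";
  return ptx;
}
```

With a check like this running first, the string builder never sees an unsupported count, so it needs no fallback branch of its own.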
https://github.com/llvm/llvm-project/pull/69467