[Mlir-commits] [mlir] [mlir][nvvm] Introduce `nvvm.stmatrix` Op (PR #69467)

Guray Ozen llvmlistbot at llvm.org
Thu Oct 19 01:26:18 PDT 2023


================
@@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+  Arguments<(ins LLVM_i8Ptr_shared:$ptr,
+                 Variadic<I32>:$sources,
+                 MMALayoutAttr:$layout)> {
+  let summary = "cooperative matrix store";
+  let description = [{
+    Collectively store one or more matrices across all threads in a warp to the
+    location indicated by the address operand $ptr in shared memory.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
+  }];
+
+  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int d = getSources().size();
+      std::string ptx = "stmatrix.sync.aligned";
+      ptx += ".x" + std::to_string(d);
+      if (getLayout() == NVVM::MMALayout::col)
+        ptx += ".trans";
+      if (d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
+      if (d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
+      if (d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
----------------
grypp wrote:

The verifier actually catches that.
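To illustrate the reply, here is a standalone C++ sketch of the string that `getPtx()` assembles. It is not MLIR code. `buildStmatrixPtx` is a hypothetical helper, and it assumes that only 1, 2, or 4 source registers are legal, since PTX defines only the `.x1`, `.x2`, and `.x4` variants of `stmatrix`. Where the op relies on its verifier, the sketch instead returns `std::nullopt` for any other count. It keeps the modifier order the patch emits (`.xN`, optional `.trans`, then `.m8n8.shared.b16`).

```cpp
#include <cassert>   // assert, for the usage checks
#include <optional>
#include <string>

// Hypothetical standalone mirror of the getPtx() logic in the patch.
// Assumption: only 1, 2 or 4 sources are valid (.x1/.x2/.x4); any other
// count is rejected here, playing the role of the op's verifier.
std::optional<std::string> buildStmatrixPtx(int numSources, bool colLayout) {
  if (numSources != 1 && numSources != 2 && numSources != 4)
    return std::nullopt;  // e.g. 3 sources: PTX has no .x3 variant

  std::string ptx = "stmatrix.sync.aligned.x" + std::to_string(numSources);
  if (colLayout)
    ptx += ".trans";  // MMALayout::col maps to the transposed form
  ptx += ".m8n8.shared.b16 [%0], {";
  // %0 is the shared-memory address; %1..%N are the i32 source registers.
  for (int i = 1; i <= numSources; ++i) {
    if (i > 1)
      ptx += ", ";
    ptx += "%" + std::to_string(i);
  }
  ptx += "};";
  return ptx;
}
```

For example, `buildStmatrixPtx(4, true)` yields `stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};`. Generating the register list in a loop replaces the three hand-written `if` branches. Once the count is known to be 1, 2, or 4, it produces the same operand lists.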

https://github.com/llvm/llvm-project/pull/69467

