[Mlir-commits] [mlir] 5316d19 - [mlir][nvvm] Introduce `nvvm.stmatrix` Op (#69467)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 19 01:26:33 PDT 2023


Author: Guray Ozen
Date: 2023-10-19T10:26:28+02:00
New Revision: 5316d19ed54d897acc0d1a5627379571fb07f0ac

URL: https://github.com/llvm/llvm-project/commit/5316d19ed54d897acc0d1a5627379571fb07f0ac
DIFF: https://github.com/llvm/llvm-project/commit/5316d19ed54d897acc0d1a5627379571fb07f0ac.diff

LOG: [mlir][nvvm] Introduce `nvvm.stmatrix` Op (#69467)

This PR adds the `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively
stores one or more matrices across all threads in a warp to the given
address in shared memory.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cefdd7cc4033a11..9cda7862ccb0fe3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1186,6 +1186,41 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+  Arguments<(ins LLVM_i8Ptr_shared:$ptr,
+                 Variadic<I32>:$sources,
+                 MMALayoutAttr:$layout)> {
+  let summary = "cooperative matrix store";
+  let description = [{
+    Collectively store one or more matrices across all threads in a warp to the
+    location indicated by the address operand $ptr in shared memory. The number
+    of `sources` (1, 2 or 4) selects the `.x1`, `.x2` or `.x4` form, and a `col`
+    layout adds the `.trans` qualifier.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
+  }];
+
+  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      // The verifier guarantees d is 1, 2 or 4.
+      int d = getSources().size();
+      std::string ptx = "stmatrix.sync.aligned";
+      ptx += ".x" + std::to_string(d);
+      if (getLayout() == NVVM::MMALayout::col)
+        ptx += ".trans";
+      // %0 is the shared-memory address and %1..%d are the sources. PTX
+      // statements are ';'-terminated, so the terminator is always emitted.
+      ptx += ".m8n8.shared.b16 [%0], {";
+      for (int i = 1; i <= d; ++i)
+        ptx += (i == 1 ? "%" : ", %") + std::to_string(i);
+      ptx += "};";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Results<(outs AnyType:$res)>,
   Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 92df023c797b1bc..3736978505707e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -717,6 +717,21 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+// Verifies that the destination pointer lives in shared memory and that the
+// number of sources matches one of the .x1/.x2/.x4 forms emitted by getPtx().
+LogicalResult NVVM::StMatrixOp::verify() {
+  unsigned addressSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+  if (addressSpace != NVVM::kSharedMemorySpace)
+    return emitOpError("expected destination pointer in memory space 3");
+
+  int numMatrix = getSources().size();
+  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
+    return emitOpError("expected number of sources to be 1, 2 or 4");
+
+  return success();
+}
+
 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
   if (typeA == NVVM::WGMMATypes::tf32)
     return 8;

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0d0ac9637438a95..3bb0ab90775edf5 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() {
 
 // -----
 
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"


        


More information about the Mlir-commits mailing list