[Mlir-commits] [mlir] [mlir][nvvm] Introduce `nvvm.stmatrix` Op (PR #69467)

Guray Ozen llvmlistbot at llvm.org
Wed Oct 18 07:30:29 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/69467

This PR adds the `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively stores one or more matrices from all threads in a warp to the given address in shared memory.

>From 0620f0bc8547b3de649adebd0e9084f1429284fc Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 18 Oct 2023 16:29:27 +0200
Subject: [PATCH] [mlir][nvvm] Introduce `nvvm.stmatrix` Op

This PR adds the `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively stores one or more matrices from all threads in a warp to the given address in shared memory.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 28 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 13 +++++++++
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 24 ++++++++++++++++
 3 files changed, 65 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cefdd7cc4033a11..7b9e83371eb4450 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1186,6 +1186,34 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+  Arguments<(ins LLVM_i8Ptr_shared:$ptr,
+                 Variadic<I32>:$sources,
+                 MMALayoutAttr:$layout)> {
+  let summary = "cooperative matrix store";
+  let description = [{
+    Collectively store 1, 2 or 4 8x8 b16 matrices (one i32 `$sources` operand
+    per matrix) across all threads in a warp to shared memory at `$ptr`.
+    A `col` layout emits the transposed (`.trans`) form of the instruction.
+    [For more information, see `stmatrix` in the PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html)
+  }];
+  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      // e.g. "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};"
+      int d = getSources().size();
+      std::string ptx = "stmatrix.sync.aligned.x" + std::to_string(d);
+      if (getLayout() == NVVM::MMALayout::col)
+        ptx += ".trans";
+      ptx += ".m8n8.shared.b16 [%0], {%1";
+      for (int i = 2; i <= d; ++i)
+        ptx += ", %" + std::to_string(i);
+      return ptx + "};"; // every PTX statement must be ';'-terminated
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Results<(outs AnyType:$res)>,
   Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 92df023c797b1bc..3736978505707e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -717,6 +717,19 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+// Checks that $ptr points to shared memory (address space 3) and that 1, 2 or
+// 4 i32 sources are given, since getPtx() only emits .x1, .x2 and .x4 forms.
+LogicalResult NVVM::StMatrixOp::verify() {
+  auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
+  if (ptrTy.getAddressSpace() != NVVM::kSharedMemorySpace)
+    return emitOpError("expected destination pointer in memory space 3");
+
+  size_t numMatrix = getSources().size();
+  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
+    return emitOpError("expected number of sources to be 1, 2 or 4");
+  return success();
+}
+
 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
   if (typeA == NVVM::WGMMATypes::tf32)
     return 8;
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0d0ac9637438a95..3bb0ab90775edf5 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() {
 
 // -----
 
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"



More information about the Mlir-commits mailing list