[Mlir-commits] [mlir] [mlir][nvvm] Introduce `nvvm.stmatrix` Op (PR #69467)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 18 07:31:44 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
This PR adds the `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively stores one or more matrices across all threads in a warp to the given address in shared memory.
---
Full diff: https://github.com/llvm/llvm-project/pull/69467.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+29)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+13)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+24)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cefdd7cc4033a11..9cda7862ccb0fe3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+  Arguments<(ins LLVM_i8Ptr_shared:$ptr,
+                 Variadic<I32>:$sources,
+                 MMALayoutAttr:$layout)> {
+  let summary = "cooperative matrix store";
+  let description = [{
+    Collectively store one or more matrices across all threads in a warp to the
+    location indicated by the address operand $ptr in shared memory.
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
+  }];
+
+  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int d = getSources().size(); // matrix count: 1, 2 or 4 (see verifier)
+      std::string ptx = "stmatrix.sync.aligned";
+      ptx += ".x" + std::to_string(d);
+      if (getLayout() == NVVM::MMALayout::col)
+        ptx += ".trans"; // `col` layout is emitted as the .trans qualifier
+      if (d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
+      if (d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
+      if (d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
+      return ptx; // every variant ends in ';' since PTX statements need it
+    }
+  }];
+  let hasVerifier = 1;
+}
+
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 92df023c797b1bc..3736978505707e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -717,6 +717,19 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return success();
}
+LogicalResult NVVM::StMatrixOp::verify() {
+  unsigned addressSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+  if (addressSpace != NVVM::kSharedMemorySpace) // $ptr is the store target
+    return emitOpError("expected destination pointer in memory space 3");
+  // One i32 source per matrix; getPtx() only emits .x1, .x2 and .x4 forms.
+  int numMatrix = getSources().size();
+  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
+    return emitOpError("expected 1, 2 or 4 source operands");
+
+  return success();
+}
+
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
if (typeA == NVVM::WGMMATypes::tf32)
return 8;
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0d0ac9637438a95..3bb0ab90775edf5 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() {
// -----
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+  llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
``````````
</details>
https://github.com/llvm/llvm-project/pull/69467
More information about the Mlir-commits
mailing list