[Mlir-commits] [mlir] 5316d19 - [mlir][nvvm] Introduce `nvvm.stmatrix` Op (#69467)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 01:26:33 PDT 2023
Author: Guray Ozen
Date: 2023-10-19T10:26:28+02:00
New Revision: 5316d19ed54d897acc0d1a5627379571fb07f0ac
URL: https://github.com/llvm/llvm-project/commit/5316d19ed54d897acc0d1a5627379571fb07f0ac
DIFF: https://github.com/llvm/llvm-project/commit/5316d19ed54d897acc0d1a5627379571fb07f0ac.diff
LOG: [mlir][nvvm] Introduce `nvvm.stmatrix` Op (#69467)
This PR adds the `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively
stores one or more matrices across all threads in a warp to the given
address location in shared memory.
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cefdd7cc4033a11..9cda7862ccb0fe3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+ Arguments<(ins LLVM_i8Ptr_shared:$ptr,
+ Variadic<I32>:$sources,
+ MMALayoutAttr:$layout)> {
+ let summary = "cooperative matrix store";
+ let description = [{
+ Collectively store one or more matrices across all threads in a warp to the
+ location indicated by the address operand $ptr in shared memory.
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
+ }];
+
+ let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ int d = getSources().size();
+ std::string ptx = "stmatrix.sync.aligned";
+ ptx += ".x" + std::to_string(d);
+ if (getLayout() == NVVM::MMALayout::col)
+ ptx += ".trans";
+ if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
+ if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
+ if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
+ return ptx;
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 92df023c797b1bc..3736978505707e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -717,6 +717,21 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return success();
}
+// Verifies that `ptr` is in the shared memory space (NVVM::kSharedMemorySpace)
+// and that 1, 2 or 4 source operands are given (selecting .x1, .x2 or .x4).
+LogicalResult NVVM::StMatrixOp::verify() {
+ unsigned addressSpace =
+ llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+ if (addressSpace != NVVM::kSharedMemorySpace)
+ return emitOpError("expected destination pointer in memory space 3");
+
+ int numMatrix = getSources().size();
+ if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
+ return emitOpError("expected 1, 2 or 4 source operands, got ") << numMatrix;
+
+ return success();
+}
+
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
if (typeA == NVVM::WGMMATypes::tf32)
return 8;
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0d0ac9637438a95..3bb0ab90775edf5 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() {
// -----
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+ nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+ nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+ nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+ nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+ nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+ nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
More information about the Mlir-commits
mailing list