[Mlir-commits] [mlir] 8819f87 - [MLIR][NVVM] Add barrier.arrive (#85412)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 19 08:51:36 PDT 2024
Author: Guray Ozen
Date: 2024-03-19T16:51:32+01:00
New Revision: 8819f8799868a95bba24c0a574a9b1455e63b63d
URL: https://github.com/llvm/llvm-project/commit/8819f8799868a95bba24c0a574a9b1455e63b63d
DIFF: https://github.com/llvm/llvm-project/commit/8819f8799868a95bba24c0a574a9b1455e63b63d.diff
LOG: [MLIR][NVVM] Add barrier.arrive (#85412)
This PR adds the `nvvm.barrier.arrive` Op. It is a useful op for
producer-consumer modeling.
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8ec8e16f75c94b..728e92c9dc8dcf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -409,6 +409,33 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
}
+// Op that lowers to the PTX `bar.arrive` instruction through inline assembly.
+// `barrierId` is optional; `numberOfThreads` is always required.
+def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">
+{
+ let arguments = (ins Optional<I32>:$barrierId, I32:$numberOfThreads);
+
+ let description = [{
+ A thread that executes this op announces its arrival at the barrier with
+ the given id and continues its execution.
+
+ The default barrier id is 0, similar to the `nvvm.barrier` Op. When
+ `barrierId` is not present, the default barrier id is used.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
+ }];
+
+ let assemblyFormat = "(`id` `=` $barrierId^)? `number_of_threads` `=` $numberOfThreads attr-dict";
+
+ let extraClassDefinition = [{
+ // Returns the PTX text for this op. `%N` placeholders refer to the op's
+ // operands in declaration order, so with an explicit id `%0` is the
+ // barrier id and `%1` the thread count; without it the literal id 0 is
+ // emitted and `%0` is the thread count.
+ std::string $cppClass::getPtx() {
+ std::string ptx = "bar.arrive ";
+ // Both forms must end with ';', the PTX statement terminator.
+ if (getBarrierId()) { ptx += "%0, %1;"; }
+ else { ptx += "0, %0;"; }
+ return ptx;
+ }
+ }];
+}
+
def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 4780ec09b81b9b..9e8407451a0855 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -214,8 +214,7 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : "
- << "(";
+ p << " : " << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -956,9 +955,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << ","
- << " $" << (regCnt + 1) << ","
- << " p";
+ ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0ac7331e1f6987..8920bf86d89b1d 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -680,3 +680,15 @@ func.func @fence_proxy() {
nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
func.return
}
+
+// -----
+
+// Checks that nvvm.barrier.arrive is converted to inline PTX assembly, both
+// without and with an explicit barrier id operand.
+// CHECK-LABEL: @llvm_nvvm_barrier_arrive
+// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
+llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
+ // No id operand: barrier id 0 is written directly into the PTX string.
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0, $0;", "r" %[[numberOfThreads]] : (i32) -> ()
+ nvvm.barrier.arrive number_of_threads = %numberOfThreads
+ // Explicit id: both operands are passed to the inline asm, and the PTX
+ // statement is terminated with ';'.
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
+ nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index f35393c5e95748..de2904d15b647b 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -55,6 +55,16 @@ llvm.func @llvm_nvvm_barrier(%barId : i32, %numberOfThreads : i32) {
llvm.return
}
+// Round-trip test for the nvvm.barrier.arrive assembly format: the op must
+// print back the same way it was parsed, with and without the optional id.
+// CHECK-LABEL: @llvm_nvvm_barrier_arrive
+// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
+llvm.func @llvm_nvvm_barrier_arrive(%barId : i32, %numberOfThreads : i32) {
+ // Form without the optional `id` clause.
+ // CHECK: nvvm.barrier.arrive number_of_threads = %[[numberOfThreads]]
+ nvvm.barrier.arrive number_of_threads = %numberOfThreads
+ // Form with an explicit `id` clause.
+ // CHECK: nvvm.barrier.arrive id = %[[barId]] number_of_threads = %[[numberOfThreads]]
+ nvvm.barrier.arrive id = %barId number_of_threads = %numberOfThreads
+ llvm.return
+}
+
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
func.func @llvm_nvvm_cluster_arrive() {
// CHECK: nvvm.cluster.arrive
More information about the Mlir-commits
mailing list