[Mlir-commits] [mlir] [MLIR][NVVM] Add barrier.arrive (PR #85412)

Guray Ozen llvmlistbot at llvm.org
Tue Mar 19 06:14:19 PDT 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/85412

>From 7b1c27b6e92d9f4fdac4514e39079d7e4ca86c7e Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Fri, 15 Mar 2024 15:11:53 +0000
Subject: [PATCH 1/2] [MLIR][NVVM] Add barrier.arrive

This PR adds the `nvvm.barrier.arrive` Op. It is a useful op for producer-consumer modeling.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 32 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  7 ++++
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 14 ++++++++
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 12 +++++++
 4 files changed, 65 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8ec8e16f75c94b..1eca9c2642111f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -409,6 +409,38 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
   let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
 }
 
+def NVVM_BarrierArriveOp : NVVM_Op<"barrier.arrive", 
+                  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+                  AttrSizedOperandSegments]> 
+{
+  let arguments = (ins     
+    Optional<I32>:$barrierId,
+    Optional<I32>:$numberOfThreads);
+  let hasVerifier = 1;
+  let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
+  let description = [{
+    Thread that executes this op announces their arrival at the barrier with 
+    given id and continue their execution.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
+  }];
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      std::string ptx = "bar.arrive ";
+      if (getNumberOfThreads() && getBarrierId()) {
+        ptx += "%0, %1";
+      } else if (getBarrierId()) {
+        ptx += "%0";
+      } else {
+        ptx += "0";
+      }
+      ptx += ";";
+      return ptx;
+    }
+  }];
+}
+
 def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
   let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 4780ec09b81b9b..d7784686bd02df 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1029,6 +1029,13 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::BarrierArriveOp::verify() {
+  if (getNumberOfThreads() && !getBarrierId())
+    return emitOpError(
+        "barrier id is missing, it should be set between 0 to 15");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0ac7331e1f6987..8367982ed0f2c7 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -680,3 +680,17 @@ func.func @fence_proxy() {
   nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: @llvm_nvvm_barrier_arrive
+// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
+llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0;", "" : () -> () 
+  nvvm.barrier.arrive
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0;", "r" %[[barId]] : (i32) -> ()
+  nvvm.barrier.arrive id = %barID
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
+  nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
+  llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index f35393c5e95748..e506cc98ca67c6 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -55,6 +55,18 @@ llvm.func @llvm_nvvm_barrier(%barId : i32, %numberOfThreads : i32) {
   llvm.return
 }
 
+// CHECK-LABEL: @llvm_nvvm_barrier_arrive
+// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
+llvm.func @llvm_nvvm_barrier_arrive(%barId : i32, %numberOfThreads : i32) {
+  // CHECK: nvvm.barrier.arrive
+  nvvm.barrier.arrive
+  // CHECK: nvvm.barrier.arrive id = %[[barId]]
+  nvvm.barrier.arrive id = %barId
+  // CHECK: nvvm.barrier.arrive id = %[[barId]] number_of_threads = %[[numberOfThreads]]
+  nvvm.barrier.arrive id = %barId number_of_threads = %numberOfThreads
+  llvm.return
+}
+
 // CHECK-LABEL: @llvm_nvvm_cluster_arrive
 func.func @llvm_nvvm_cluster_arrive() {
   // CHECK: nvvm.cluster.arrive

>From cdd3fbffd5e5f2c396be21f85779208c8d477ae3 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 19 Mar 2024 13:13:24 +0000
Subject: [PATCH 2/2] address comments

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 27 ++++++++-----------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 14 ++--------
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   |  8 +++---
 mlir/test/Dialect/LLVMIR/nvvm.mlir            |  6 ++---
 4 files changed, 18 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1eca9c2642111f..728e92c9dc8dcf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -409,33 +409,28 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
   let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
 }
 
-def NVVM_BarrierArriveOp : NVVM_Op<"barrier.arrive", 
-                  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
-                  AttrSizedOperandSegments]> 
+def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive"> 
 {
-  let arguments = (ins     
-    Optional<I32>:$barrierId,
-    Optional<I32>:$numberOfThreads);
-  let hasVerifier = 1;
-  let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
+  let arguments = (ins Optional<I32>:$barrierId, I32:$numberOfThreads);
+
   let description = [{
     Thread that executes this op announces their arrival at the barrier with 
     given id and continue their execution.
 
+    The default barrier id is 0, the same as for the `nvvm.barrier` Op. When
+    `barrierId` is not present, the default barrier id is used.
+
     [For more information, see PTX ISA]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
   }];
+  
+  let assemblyFormat = "(`id` `=` $barrierId^)? `number_of_threads` `=` $numberOfThreads attr-dict";
+
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       std::string ptx = "bar.arrive ";
-      if (getNumberOfThreads() && getBarrierId()) {
-        ptx += "%0, %1";
-      } else if (getBarrierId()) {
-        ptx += "%0";
-      } else {
-        ptx += "0";
-      }
-      ptx += ";";
+      if (getBarrierId()) { ptx += "%0, %1;"; } 
+      else { ptx += "0, %0;"; }
       return ptx;
     }
   }];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d7784686bd02df..9e8407451a0855 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -214,8 +214,7 @@ void MmaOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
 
   // Print the types of the operands and result.
-  p << " : "
-    << "(";
+  p << " : " << "(";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
                                              frags[2].regs[0].getType()},
@@ -956,9 +955,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << "},";
   // Need to map read/write registers correctly.
   regCnt = (regCnt * 2);
-  ss << " $" << (regCnt) << ","
-     << " $" << (regCnt + 1) << ","
-     << " p";
+  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
   if (getTypeD() != WGMMATypes::s32) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
@@ -1029,13 +1026,6 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
-LogicalResult NVVM::BarrierArriveOp::verify() {
-  if (getNumberOfThreads() && !getBarrierId())
-    return emitOpError(
-        "barrier id is missing, it should be set between 0 to 15");
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 8367982ed0f2c7..8920bf86d89b1d 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -686,11 +686,9 @@ func.func @fence_proxy() {
 // CHECK-LABEL: @llvm_nvvm_barrier_arrive
 // CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
 llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0;", "" : () -> () 
-  nvvm.barrier.arrive
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0;", "r" %[[barId]] : (i32) -> ()
-  nvvm.barrier.arrive id = %barID
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0, $0;", "r" %[[numberOfThreads]] : (i32) -> ()
+  nvvm.barrier.arrive number_of_threads = %numberOfThreads
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
   nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
   llvm.return
 }
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e506cc98ca67c6..de2904d15b647b 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -58,10 +58,8 @@ llvm.func @llvm_nvvm_barrier(%barId : i32, %numberOfThreads : i32) {
 // CHECK-LABEL: @llvm_nvvm_barrier_arrive
 // CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
 llvm.func @llvm_nvvm_barrier_arrive(%barId : i32, %numberOfThreads : i32) {
-  // CHECK: nvvm.barrier.arrive
-  nvvm.barrier.arrive
-  // CHECK: nvvm.barrier.arrive id = %[[barId]]
-  nvvm.barrier.arrive id = %barId
+  // CHECK: nvvm.barrier.arrive number_of_threads = %[[numberOfThreads]]
+  nvvm.barrier.arrive number_of_threads = %numberOfThreads
   // CHECK: nvvm.barrier.arrive id = %[[barId]] number_of_threads = %[[numberOfThreads]]
   nvvm.barrier.arrive id = %barId number_of_threads = %numberOfThreads
   llvm.return



More information about the Mlir-commits mailing list