[Mlir-commits] [mlir] [mlir:nvgpu] Make `mbarrier.try_wait` fallible (PR #96508)

Chris Jones llvmlistbot at llvm.org
Mon Jun 24 08:20:05 PDT 2024


https://github.com/chr1sj0nes created https://github.com/llvm/llvm-project/pull/96508

The op will make a single attempt to wait, then fail, rather than looping. This matches the behaviour described in the docs, and means that the `ticks` parameter is actually meaningful.

>From 316efd4ee6b86237267fc296892c7d03aa428333 Mon Sep 17 00:00:00 2001
From: Chris Jones <cjfj at deepmind.com>
Date: Mon, 24 Jun 2024 16:16:07 +0100
Subject: [PATCH] [mlir:nvgpu] Make `mbarrier.try_wait` non-blocking.

This matches the behaviour described in the docs, and means that the `ticks` parameter is actually meaningful.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 36 +++++-----------
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   |  3 +-
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  7 ++--
 .../NVGPU/TransformOps/NVGPUTransformOps.cpp  |  2 +-
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 41 +++++++++----------
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 36 +++++-----------
 6 files changed, 48 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4d48b3de7a57e..eac71a59841e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -323,40 +323,24 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.ex
   }];
 }
 
-def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,  
-  Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {  
-  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
+  Results<(outs I1:$res)>, // waitComplete predicate; PTX writes a .pred register ("=b" constraint).
+  Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
+  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
-      return std::string(
-        "{\n\t"
-        ".reg .pred       P1; \n\t"
-        "LAB_WAIT: \n\t"
-        "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
-        "@P1 bra.uni DONE; \n\t"
-        "bra.uni     LAB_WAIT; \n\t"
-        "DONE: \n\t"
-        "}"
-      ); 
+      return std::string("mbarrier.try_wait.parity.b64 %0, [%1], %2, %3;"); // %0=$res, %1=$addr, %2=$phase, %3=$ticks
     }
   }];
 }
 
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,  
-  Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {  
-  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
+  Results<(outs I1:$res)>, // waitComplete predicate; PTX writes a .pred register ("=b" constraint).
+  Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
+  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
-      return std::string(
-        "{\n\t"
-        ".reg .pred       P1; \n\t"
-        "LAB_WAIT: \n\t"
-        "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
-        "@P1 bra.uni DONE; \n\t"
-        "bra.uni     LAB_WAIT; \n\t"
-        "DONE: \n\t"
-        "}"
-      ); 
+      return std::string("mbarrier.try_wait.parity.shared.b64 %0, [%1], %2, %3;"); // %0=$res, %1=$addr, %2=$phase, %3=$ticks
     }
   }];
 }
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe..d946a1cf49ee2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -614,10 +614,11 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
 
     Example:
     ```mlir
-      nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+      %isComplete = nvgpu.mbarrier.try_wait.parity %barriers[%mbarId], %phaseParity, %ticks : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
     ```
   }];
   let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
+  let results = (outs I1:$waitComplete);
   let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";  
 }
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..bcc1d9eb7766c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -958,15 +958,16 @@ struct NVGPUMBarrierTryWaitParityLowering
     Value ticks = truncToI32(b, adaptor.getTicks());
     Value phase =
         b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
+    Type retType = rewriter.getI1Type();
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
-          op, barrier, phase, ticks);
+          op, retType, barrier, phase, ticks);
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
-                                                               phase, ticks);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(
+        op, retType, barrier, phase, ticks);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 4e256aea0be37..356c1621f81e6 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -1018,7 +1018,7 @@ void HopperBuilder::buildTryWaitParity(
   Value ticksBeforeRetry =
       rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
+  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, i1, barrier, parity,
                                                   ticksBeforeRetry, zero);
 }
 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 86a552c03a473..5d8e5d1e5a2db 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -558,7 +558,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
 }
 
 // CHECK-LABEL: func @mbarrier_txcount
-func.func @mbarrier_txcount() {
+func.func @mbarrier_txcount() -> i1 {
     %num_threads = arith.constant 128 : index
     // CHECK: %[[c0:.+]] = arith.constant 0 : index
     // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
@@ -568,50 +568,49 @@ func.func @mbarrier_txcount() {
     %barrier = nvgpu.mbarrier.create -> !barrierType
 
     // CHECK: %[[barStr:.+]] =  builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
     nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType
-    
+
     %tidxreg = nvvm.read.ptx.sreg.tid.x : i32
     %tidx = arith.index_cast %tidxreg : i32 to index
-    %cnd = arith.cmpi eq, %tidx, %c0 : index  
+    %cnd = arith.cmpi eq, %tidx, %c0 : index
 
     scf.if %cnd {
       %txcount = arith.constant 256 : index
-      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
       // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
       // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
       nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
-      scf.yield 
+      scf.yield
     } else {
       %txcount = arith.constant 0 : index
-      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
       // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
       // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
       nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
-      scf.yield 
+      scf.yield
     }
-      
 
     %phase_c0 = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
-    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
-    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
-    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+    // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+    %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
 
-    func.return 
+    func.return %isDone : i1
 }
 
 // CHECK-LABEL: func @mbarrier_txcount_pred
-func.func @mbarrier_txcount_pred() {
+func.func @mbarrier_txcount_pred() -> i1 {
     %mine = arith.constant 1 : index
     // CHECK: %[[c0:.+]] = arith.constant 0 : index
     // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
     // CHECK: %[[S2:.+]] = gpu.thread_id  x
     // CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
-    %c0 = arith.constant 0 : index  
+    %c0 = arith.constant 0 : index
     %tidx = gpu.thread_id x
     %pred = arith.cmpi eq, %tidx, %c0 : index
 
@@ -619,25 +618,25 @@ func.func @mbarrier_txcount_pred() {
     %barrier = nvgpu.mbarrier.create -> !barrierType
 
     // CHECK: %[[barStr:.+]] =  builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
     nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType
     
     %txcount = arith.constant 256 : index
-    // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
     nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
 
     %phase_c0 = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
-    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
-    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
-    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+    // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+    %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
 
-    func.return 
+    func.return %isDone : i1
 }
 
 // CHECK-LABEL: func @async_tma_load
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 21947c242461e..a2ec210985863 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -32,35 +32,21 @@ llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount
 }
 
 // CHECK-LABEL: @init_mbarrier_try_wait_shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{
-  // CHECK-SAME: .reg .pred       P1;
-  // CHECK-SAME: LAB_WAIT: 
-  // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
-  // CHECK-SAME: @P1 bra.uni DONE;
-  // CHECK-SAME: bra.uni     LAB_WAIT;
-  // CHECK-SAME: DONE:
-  // CHECK-SAME: }",
-  // CHECK-SAME: "r,r,r"
-   nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
-  llvm.return
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) -> i1 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+  // CHECK-SAME: "mbarrier.try_wait.parity.shared.b64 $0, [$1], $2, $3;",
+  // CHECK-SAME: "=b,r,r,r"
+  %isDone = nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1
+  llvm.return %isDone : i1
 }
 
 // CHECK-LABEL: @init_mbarrier_try_wait
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32) -> i1 {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "{
-  // CHECK-SAME: .reg .pred       P1;
-  // CHECK-SAME: LAB_WAIT: 
-  // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
-  // CHECK-SAME: @P1 bra.uni DONE;
-  // CHECK-SAME: bra.uni     LAB_WAIT;
-  // CHECK-SAME: DONE:
-  // CHECK-SAME: }",
-  // CHECK-SAME: "l,r,r"
-  nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
-  llvm.return
+  // CHECK-SAME: "mbarrier.try_wait.parity.b64 $0, [$1], $2, $3;",
+  // CHECK-SAME: "=b,l,r,r"
+  %isDone = nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1
+  llvm.return %isDone : i1
 }
 
 // CHECK-LABEL: @async_cp



More information about the Mlir-commits mailing list