[Mlir-commits] [mlir] [mlir:nvgpu] Make `mbarrier.try_wait` fallible (PR #96508)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 24 08:20:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Chris Jones (chr1sj0nes)
<details>
<summary>Changes</summary>
The op now makes a single attempt to wait and returns an `i1` indicating whether the wait completed, rather than looping until it does. This matches the behaviour described in the docs, and means that the `ticks` parameter is actually meaningful.
---
Full diff: https://github.com/llvm/llvm-project/pull/96508.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+10-26)
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+2-1)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-3)
- (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+1-1)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+20-21)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+11-25)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4d48b3de7a57e..eac71a59841e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -323,40 +323,24 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.ex
}];
}
-def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
- let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
+ let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
+ return std::string("mbarrier.try_wait.parity.b64 %0, [%1], %2, %3;");
}
}];
}
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
- let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
+ let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
+ return std::string("mbarrier.try_wait.parity.shared.b64 %0, [%1], %2, %3;");
}
}];
}
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe..d946a1cf49ee2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -614,10 +614,11 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
Example:
```mlir
- nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+ %isComplete = nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
```
}];
let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
+ let results = (outs I1:$waitComplete);
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";
}
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..bcc1d9eb7766c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -958,15 +958,16 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
+ Type retType = rewriter.getI1Type();
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
+ op, retType, barrier, phase, ticks);
return success();
}
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
- phase, ticks);
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(
+ op, retType, barrier, phase, ticks);
return success();
}
};
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 4e256aea0be37..356c1621f81e6 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -1018,7 +1018,7 @@ void HopperBuilder::buildTryWaitParity(
Value ticksBeforeRetry =
rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
+ rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, i1, barrier, parity,
ticksBeforeRetry, zero);
}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 86a552c03a473..5d8e5d1e5a2db 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -558,7 +558,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
}
// CHECK-LABEL: func @mbarrier_txcount
-func.func @mbarrier_txcount() {
+func.func @mbarrier_txcount() -> i1 {
%num_threads = arith.constant 128 : index
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
@@ -568,50 +568,49 @@ func.func @mbarrier_txcount() {
%barrier = nvgpu.mbarrier.create -> !barrierType
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType
-
+
%tidxreg = nvvm.read.ptx.sreg.tid.x : i32
%tidx = arith.index_cast %tidxreg : i32 to index
- %cnd = arith.cmpi eq, %tidx, %c0 : index
+ %cnd = arith.cmpi eq, %tidx, %c0 : index
scf.if %cnd {
%txcount = arith.constant 256 : index
- // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
- scf.yield
+ scf.yield
} else {
%txcount = arith.constant 0 : index
- // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
- scf.yield
+ scf.yield
}
-
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
- // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
- nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+ // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
- func.return
+ func.return %isDone : i1
}
// CHECK-LABEL: func @mbarrier_txcount_pred
-func.func @mbarrier_txcount_pred() {
+func.func @mbarrier_txcount_pred() -> i1 {
%mine = arith.constant 1 : index
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
// CHECK: %[[S2:.+]] = gpu.thread_id x
// CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
- %c0 = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
%tidx = gpu.thread_id x
%pred = arith.cmpi eq, %tidx, %c0 : index
@@ -619,25 +618,25 @@ func.func @mbarrier_txcount_pred() {
%barrier = nvgpu.mbarrier.create -> !barrierType
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType
%txcount = arith.constant 256 : index
- // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
- // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
- nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+ // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
- func.return
+ func.return %isDone : i1
}
// CHECK-LABEL: func @async_tma_load
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 21947c242461e..a2ec210985863 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -32,35 +32,21 @@ llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount
}
// CHECK-LABEL: @init_mbarrier_try_wait_shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred P1;
- // CHECK-SAME: LAB_WAIT:
- // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
- // CHECK-SAME: @P1 bra.uni DONE;
- // CHECK-SAME: bra.uni LAB_WAIT;
- // CHECK-SAME: DONE:
- // CHECK-SAME: }",
- // CHECK-SAME: "r,r,r"
- nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
- llvm.return
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) -> i1 {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "mbarrier.try_wait.parity.shared.b64 $0, [$1], $2, $3;",
+ // CHECK-SAME: "=b,r,r,r"
+ %isDone = nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1
+ llvm.return %isDone : i1
}
// CHECK-LABEL: @init_mbarrier_try_wait
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32) -> i1 {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred P1;
- // CHECK-SAME: LAB_WAIT:
- // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
- // CHECK-SAME: @P1 bra.uni DONE;
- // CHECK-SAME: bra.uni LAB_WAIT;
- // CHECK-SAME: DONE:
- // CHECK-SAME: }",
- // CHECK-SAME: "l,r,r"
- nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
- llvm.return
+ // CHECK-SAME: "mbarrier.try_wait.parity.b64 $0, [$1], $2, $3;",
+ // CHECK-SAME: "=b,l,r,r"
+ %isDone = nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1
+ llvm.return %isDone : i1
}
// CHECK-LABEL: @async_cp
``````````
</details>
https://github.com/llvm/llvm-project/pull/96508
More information about the Mlir-commits
mailing list