[Mlir-commits] [mlir] 836dbb8 - [mlir][nvgpu] Add `mbarrier.arrive.expect_tx` and `mbarrier.try_wait.parity`
Guray Ozen
llvmlistbot at llvm.org
Thu Jul 20 04:48:35 PDT 2023
Author: Guray Ozen
Date: 2023-07-20T13:48:30+02:00
New Revision: 836dbb8522abcf27f10bcf0134872b93676d9064
URL: https://github.com/llvm/llvm-project/commit/836dbb8522abcf27f10bcf0134872b93676d9064
DIFF: https://github.com/llvm/llvm-project/commit/836dbb8522abcf27f10bcf0134872b93676d9064.diff
LOG: [mlir][nvgpu] Add `mbarrier.arrive.expect_tx` and `mbarrier.try_wait.parity`
This work adds two Ops:
`mbarrier.arrive.expect_tx` performs an expect_tx operation on `mbarrier.barrier` and returns `mbarrier.barrier.token`
`mbarrier.try_wait.parity` waits on `mbarrier.barrier` and `mbarrier.barrier.token`
`mbarrier.arrive.expect_tx` is one of the requirements for enabling H100 TMA support.
Depends on D154074 D154076 D154059 D154060
Reviewed By: qcolombet
Differential Revision: https://reviews.llvm.org/D154094
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ef17a6c2ac4ff5..c3cafe6b33c6c7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -372,53 +372,59 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
}
def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Results<(outs LLVM_Type:$res)>,
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+ let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"); }
+ std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
}];
}
def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Results<(outs LLVM_Type:$res)>,
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+ let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"); }
+ std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
}];
}
def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
- let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {
+ let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
- return std::string("{\n\t"
- ".reg .pred P1; \n\t"
- "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t"
- "selp.b32 %0, 1, 0, P1; \n\t"
- "}");
+ return std::string(
+ "{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}"
+ );
}
}];
}
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {
- let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {
+ let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
- return std::string("{\n\t"
- ".reg .pred P1; \n\t"
- "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
- "selp.b32 %0, 1, 0, P1; \n\t"
- "}");
+ return std::string(
+ "{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}"
+ );
}
}];
}
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 9e783d0c928e16..eb0bdee0b55f17 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -469,4 +469,44 @@ def NVGPU_MBarrierArriveNoCompleteOp : NVGPU_Op<"mbarrier.arrive.nocomplete", []
let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier) `->` type($token)";
}
+def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> {
+ let summary = "Performs expect_tx operation on the `nvgpu.mbarrier.arrive`";
+ let description = [{
+ A thread executing the Op performs an expect-tx operation on the mbarrier
+ object at the location specified by the address operand $barrier. The
+ expect-tx operation, with a $txcount argument, increases the tx-count of
+ an mbarrier object by the value specified by $txcount. This makes the
+ current phase of the mbarrier object expect and track the completion of
+ additional asynchronous transactions.
+
+ The `$txcount` specifies the number of elements for the expect-tx operation.
+
+ Example:
+ ```mlir
+ nvgpu.mbarrier.arrive.expect_tx %barrier, %ic0 : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+ ```
+ }];
+ let arguments = (ins NVGPU_MBarrier:$barrier,
+ Index:$txcount);
+ let assemblyFormat = "$barrier `,` $txcount attr-dict `:` type($barrier)";
+}
+
+def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
+ let summary = "Waits for the `nvgpu.mbarrier` to complete its current phase.";
+ let description = [{
+ Checks whether the mbarrier object has completed the phase. It is a
+ potentially blocking instruction which tests for the completion of the
+ phase. The suspended thread resumes execution when the specified phase completes
+ OR before the phase completes following a system-dependent time limit.
+
+ Example:
+ ```mlir
+ nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+ ```
+
+ }];
+ let arguments = (ins NVGPU_MBarrier:$barrier, Index:$phase, Index:$ticks);
+ let assemblyFormat = "$barrier `,` $phase `,` $ticks attr-dict `:` type($barrier)";
+}
+
#endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b8adef23b06e6a..26be5c03546c29 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -25,6 +25,17 @@ namespace mlir {
using namespace mlir;
+/// GPUs have 32-bit registers; this function truncates values when the larger
+/// width is not needed.
+static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
+ Value value) {
+ Type type = value.getType();
+ assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
+ if (type.getIntOrFloatBitWidth() <= 32)
+ return value;
+ return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
+}
+
/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
@@ -850,6 +861,55 @@ struct NVGPUMBarrierTestWaitLowering
}
};
+struct NVGPUMBarrierArriveExpectTxLowering
+ : public ConvertOpToLLVMPattern<nvgpu::MBarrierArriveExpectTxOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::MBarrierArriveExpectTxOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
+ op.getBarrier(), adaptor.getBarrier());
+ Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount());
+
+ if (isMbarrierShared(op.getBarrier().getType())) {
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
+ op, barrier, txcount);
+ return success();
+ }
+
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
+ txcount);
+ return success();
+ }
+};
+
+struct NVGPUMBarrierTryWaitParityLowering
+ : public ConvertOpToLLVMPattern<nvgpu::MBarrierTryWaitParityOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::MBarrierTryWaitParityOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
+ op.getBarrier(), adaptor.getBarrier());
+ Value ticks = truncToI32(rewriter, op->getLoc(), adaptor.getTicks());
+ Value phase = truncToI32(rewriter, op->getLoc(), adaptor.getPhase());
+
+ if (isMbarrierShared(op.getBarrier().getType())) {
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
+ op, barrier, phase, ticks);
+ return success();
+ }
+
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
+ phase, ticks);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -859,7 +919,9 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
- NVGPUMBarrierTestWaitLowering, // nvgpu.try_wait_parity
+ NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
+ NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
+ NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 7a7f65f3d945bd..c7a0c7f4b3ea94 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -558,3 +558,49 @@ func.func @mbarrier_nocomplete() {
func.return
}
+
+
+// -----
+!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+!tokenType = !nvgpu.mbarrier.token
+
+// CHECK-LABEL: func @mbarrier_txcount
+func.func @mbarrier_txcount() {
+ %num_threads = arith.constant 128 : index
+
+ // CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier : memref<1xi64, 3>
+ %barrier = nvgpu.mbarrier.create -> !barrierType
+
+ // CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[barPtr:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
+ nvgpu.mbarrier.init %barrier, %num_threads : !barrierType
+
+ %c0 = arith.constant 0 : index
+ %tidxreg = nvvm.read.ptx.sreg.tid.x : i32
+ %tidx = arith.index_cast %tidxreg : i32 to index
+ %cnd = arith.cmpi eq, %tidx, %c0 : index
+
+ scf.if %cnd {
+ %txcount = arith.constant 256 : index
+ // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType
+ scf.yield
+ } else {
+ %txcount = arith.constant 0 : index
+ // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType
+ scf.yield
+ }
+
+
+ %phase = arith.constant 0 : index
+ %ticks = arith.constant 10000000 : index
+ // CHECK: %[[barPtr3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !barrierType
+
+ func.return
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5d3218ef1c7f5d..0d93072b695243 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -1,31 +1,31 @@
// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s
// CHECK-LABEL : @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i64 {
- //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 $0, [$1], $2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64
- %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
- llvm.return %res : i64
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {
+ //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
+ nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+ llvm.return
}
// CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i64 {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 $0, [$1], $2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64
- %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64
- llvm.return %res : i64
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r"
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+ llvm.return
}
// CHECK-LABEL : @init_mbarrier_try_wait.parity.shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 {
- // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
- %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32
- llvm.return %res : i32
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
+ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni LAB_WAIT; \0A\09DONE: \0A\09}", "r,r,r"
+ nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+ llvm.return
}
// CHECK-LABEL : @init_mbarrier_try_wait.parity
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %token : i32) -> i32{
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
- %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32
- llvm.return %res : i32
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni LAB_WAIT; \0A\09DONE: \0A\09}", "r,r,r"
+ nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
+ llvm.return
}
// CHECK-LABEL : @async_cp
More information about the Mlir-commits
mailing list