[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.mbarrier.group` for multiple mbarrier use (PR #65951)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 11 05:07:39 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
A common practice is to create multiple `mbarrier` objects; see the example below. This is particularly valuable in scenarios like software pipelining for GEMM, where we need to create multiple barriers and then use and wait on them dynamically in a loop.
This PR replaces the `nvgpu.mbarrier.barrier` type with `nvgpu.mbarrier.group`. All `mbarrier`-related Ops now use this type and take an index selecting which barrier in the group to operate on. Consequently, these Ops can now manage multiple barriers seamlessly.
Fixing the count in the type (e.g. `num_barriers = 4`) lets us place the mbarrier object(s) in static shared memory. We could make the value dynamic, but that would require dynamic shared memory, which would complicate the codegen.
```
%barriers = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<3, num_barriers = 4>
nvgpu.mbarrier.init %barriers[%c0], %num_threads : !nvgpu.mbarrier.group<3, num_barriers = 4>
nvgpu.mbarrier.init %barriers[%c1], %num_threads : !nvgpu.mbarrier.group<3, num_barriers = 4>
nvgpu.mbarrier.init %barriers[%c2], %num_threads : !nvgpu.mbarrier.group<3, num_barriers = 4>
nvgpu.mbarrier.init %barriers[%c3], %num_threads : !nvgpu.mbarrier.group<3, num_barriers = 4>
...
scf.for %i = %c0 to %n step %c1 {
nvgpu.mbarrier.try_wait %barriers[ (i % 4) ] ...
// ... Do work once mbarrier is ready
nvgpu.mbarrier.arrive.expect_tx %barriers[ ((i + 3) % 4) ] ...
}
```
This results in the following mbarrier usage pattern:
```
expect_tx[0]
expect_tx[1]
expect_tx[2]
Loop:
try_wait mbarrier[0], expect_tx[3]
try_wait mbarrier[1], expect_tx[0]
try_wait mbarrier[2], expect_tx[1]
try_wait mbarrier[3], expect_tx[2]
...
```
--
Patch is 43.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65951.diff
5 Files Affected:
- (modified) mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h (+3-3)
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+34-28)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+71-64)
- (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+23-18)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+71-31)
<pre>
diff --git a/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h b/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
index 8c5667cd417f0d4..4b8d5c5fe2a893d 100644
--- a/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
+++ b/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
@@ -23,15 +23,15 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace nvgpu {
-class MBarrierType;
+class MBarrierGroupType;
/// Returns the memory space attribute of the mbarrier object.
Attribute getMbarrierMemorySpace(MLIRContext *context,
- MBarrierType barrierType);
+ MBarrierGroupType barrierType);
/// Return the memref type that can be used to represent an mbarrier object.
MemRefType getMBarrierMemrefType(MLIRContext *context,
- MBarrierType barrierType);
+ MBarrierGroupType barrierType);
} // namespace nvgpu
void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a3245bf9196eed1..cc09945b477d8fa 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -135,20 +135,26 @@ def NVGPU_DeviceAsyncToken : NVGPU_Type<"DeviceAsyncToken",
}];
}
-def NVGPU_MBarrier : NVGPU_Type<"MBarrier", "mbarrier.barrier", []> {
+def NVGPU_MBarrierGroup : NVGPU_Type<"MBarrierGroup", "mbarrier.group", []> {
let summary = "mbarrier barrier type";
let description = [{
- This is the type for a mbarrier object in shared memory that is used
- to synchronize a variable number of threads.
+ This is the type for one or more mbarrier object in shared memory that is
+ used to synchronize a variable number of threads.
- The mbarrier object is 64 bit with 8 byte alignment. The mbarrier object
- can be initiated and invalidated.
+ If `num_barriers` is not set, the number of mbarrier objects is 1.
- See for more details:
- https://docs.nvidia.com/cuda/parallel-thread-execution/#size-and-alignment-of-mbarrier-object
+ A mbarrier object is 64 bit with 8 byte alignment. The mbarrier object
+ can be initiated and invalidated.
+
+ [See for more details in PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#size-and-alignment-of-mbarrier-object)
}];
- let parameters = (ins "Attribute":$memorySpace);
+ let parameters = (ins "Attribute":$memorySpace, DefaultValuedParameter<"unsigned", "1">:$num_barriers);
let assemblyFormat = "`<` struct(params) `>`";
+ let builders = [
+ TypeBuilder<(ins "Attribute":$memorySpace), [{
+ return $_get($_ctxt, memorySpace, 1);
+ }]>
+ ];
}
def NVGPU_MBarrierToken : NVGPU_Type<"MBarrierToken", "mbarrier.token", []> { }
@@ -473,7 +479,7 @@ def NVGPU_DeviceAsyncWaitOp : NVGPU_Op<"device_async_wait", []> {
def NVGPU_MBarrierCreateOp : NVGPU_Op<"mbarrier.create", []> {
let summary = "Creates a `nvgpu.mbarrier` object.";
let description = [{
- The Op generates an `mbarrier` object, which is a barrier created in
+ The Op generates one or more `mbarrier` object, which is a barrier created in
shared memory and supports various synchronization behaviors for threads.
The `mbarrier` object has the following type and alignment requirements:
@@ -485,9 +491,9 @@ def NVGPU_MBarrierCreateOp : NVGPU_Op<"mbarrier.create", []> {
```
}];
let arguments = (ins);
- let results = (outs NVGPU_MBarrier:$barrier);
+ let results = (outs NVGPU_MBarrierGroup:$barriers);
let assemblyFormat = [{
- attr-dict `->` type($barrier)
+ attr-dict `->` type($barriers)
}];
}
@@ -503,8 +509,8 @@ def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier.init", []> {
nvgpu.mbarrier.init %barrier, %num_threads : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
```
}];
- let arguments = (ins NVGPU_MBarrier:$barrier, Index:$count);
- let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier)";
+ let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$count, Index:$mbarId);
+ let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count attr-dict `:` type($barriers)";
}
def NVGPU_MBarrierTestWaitOp : NVGPU_Op<"mbarrier.test.wait", []> {
@@ -518,9 +524,9 @@ def NVGPU_MBarrierTestWaitOp : NVGPU_Op<"mbarrier.test.wait", []> {
%isComplete = nvgpu.mbarrier.test.wait %barrier, %token : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>, !nvgpu.mbarrier.token
```
}];
- let arguments = (ins NVGPU_MBarrier:$barrier, NVGPU_MBarrierToken:$token);
+ let arguments = (ins NVGPU_MBarrierGroup:$barriers, NVGPU_MBarrierToken:$token, Index:$mbarId);
let results = (outs I1:$waitComplete);
- let assemblyFormat = "$barrier `,` $token attr-dict `:` type($barrier) `,` type($token)";
+ let assemblyFormat = "$barriers `[` $mbarId `]` `,` $token attr-dict `:` type($barriers) `,` type($token)";
}
def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier.arrive", []> {
@@ -537,9 +543,9 @@ def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier.arrive", []> {
%token = nvgpu.mbarrier.arrive %barrier : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>> -> !nvgpu.mbarrier.token
```
}];
- let arguments = (ins NVGPU_MBarrier:$barrier);
+ let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$mbarId);
let results = (outs NVGPU_MBarrierToken:$token);
-let assemblyFormat = "$barrier attr-dict `:` type($barrier) `->` type($token)";
+let assemblyFormat = "$barriers `[` $mbarId `]` attr-dict `:` type($barriers) `->` type($token)";
}
def NVGPU_MBarrierArriveNoCompleteOp : NVGPU_Op<"mbarrier.arrive.nocomplete", []> {
@@ -555,10 +561,10 @@ def NVGPU_MBarrierArriveNoCompleteOp : NVGPU_Op<"mbarrier.arrive.nocomplete", []
%token = nvgpu.mbarrier.arrive.noComplete %barrier, %count : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>> -> !nvgpu.mbarrier.token
```
}];
- let arguments = (ins NVGPU_MBarrier:$barrier,
+ let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$mbarId,
Index:$count);
let results = (outs NVGPU_MBarrierToken:$token);
- let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier) `->` type($token)";
+ let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count attr-dict `:` type($barriers) `->` type($token)";
}
def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> {
@@ -578,9 +584,8 @@ def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> {
nvgpu.mbarrier.arrive.expect_tx %barrier, %ic0 : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
```
}];
- let arguments = (ins NVGPU_MBarrier:$barrier,
- Index:$txcount);
- let assemblyFormat = "$barrier `,` $txcount attr-dict `:` type($barrier)";
+ let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$txcount, Index:$mbarId);
+ let assemblyFormat = "$barriers `[` $mbarId `]` `,` $txcount attr-dict `:` type($barriers)";
}
def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
@@ -597,8 +602,8 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
```
}];
- let arguments = (ins NVGPU_MBarrier:$barrier, Index:$phase, Index:$ticks);
- let assemblyFormat = "$barrier `,` $phase `,` $ticks attr-dict `:` type($barrier)";
+ let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$phase, Index:$ticks, Index:$mbarId);
+ let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phase `,` $ticks attr-dict `:` type($barriers)";
}
def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
@@ -613,12 +618,13 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
The Op uses `$barrier` mbarrier based completion mechanism.
}];
let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$dst,
- NVGPU_MBarrier:$barrier,
+ NVGPU_MBarrierGroup:$barriers,
NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
- Variadic<Index>:$coordinates);
+ Variadic<Index>:$coordinates,
+ Index:$mbarId);
let assemblyFormat = [{
- $tensorMapDescriptor `[` $coordinates `]` `,` $barrier `to` $dst
- attr-dict `:` type($tensorMapDescriptor) `,` type($barrier) `->` type($dst)
+ $tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]` `to` $dst
+ attr-dict `:` type($tensorMapDescriptor) `,` type($barriers) `->` type($dst)
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..b008572eb443b18 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,8 +17,10 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -212,14 +214,14 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
}
/// Returns whether mbarrier object has shared memory address space.
-static bool isMbarrierShared(nvgpu::MBarrierType barrierType) {
+static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
barrierType.getMemorySpace()));
}
/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
- nvgpu::MBarrierType barrierType) {
+ nvgpu::MBarrierGroupType barrierType) {
Attribute memorySpace = {};
if (isMbarrierShared(barrierType)) {
memorySpace =
@@ -230,25 +232,13 @@ Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
}
/// Returns memref type of the mbarrier object. The type is defined in the
-/// MBarrierType.
+/// MBarrierGroupType.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
- nvgpu::MBarrierType barrierType) {
+ nvgpu::MBarrierGroupType barrierType) {
Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
MemRefLayoutAttrInterface layout;
- return MemRefType::get({1}, IntegerType::get(context, 64), layout,
- memorySpace);
-}
-
-/// Returns the base pointer of the mbarrier object.
-static Value getMbarrierPtr(ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &typeConverter,
- TypedValue<nvgpu::MBarrierType> barrier,
- Value barrierMemref) {
- MemRefType memrefType =
- nvgpu::getMBarrierMemrefType(rewriter.getContext(), barrier.getType());
- MemRefDescriptor memRefDescriptor(barrierMemref);
- return memRefDescriptor.bufferPtr(rewriter, barrier.getLoc(), typeConverter,
- memrefType);
+ return MemRefType::get({barrierType.getNumBarriers()},
+ IntegerType::get(context, 64), layout, memorySpace);
}
namespace {
@@ -426,7 +416,7 @@ struct ConvertNVGPUToNVVMPass
[&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 64));
});
- converter.addConversion([&](nvgpu::MBarrierType type) -> Type {
+ converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
return converter.convertType(
nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
});
@@ -762,7 +752,7 @@ struct NVGPUMBarrierCreateLowering
ConversionPatternRewriter &rewriter) const override {
Operation *funcOp = op->getParentOp();
MemRefType barrierType = nvgpu::getMBarrierMemrefType(
- rewriter.getContext(), op.getBarrier().getType());
+ rewriter.getContext(), op.getBarriers().getType());
memref::GlobalOp global;
if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
@@ -777,21 +767,37 @@ struct NVGPUMBarrierCreateLowering
}
};
+/// Base class for lowering mbarrier operations to nvvm intrinsics.
+template <typename SourceOp>
+struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
+public:
+ using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+ /// Returns the base pointer of the mbarrier object.
+ Value getMbarrierPtr(Operation *op, nvgpu::MBarrierGroupType mbarType,
+ Value memrefDesc, Value mbarId,
+ ConversionPatternRewriter &rewriter) const {
+ MemRefType mbarrierMemrefType =
+ nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
+ return ConvertToLLVMPattern::getStridedElementPtr(
+ op->getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
+ return memrefDesc;
+ }
+};
+
/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
struct NVGPUMBarrierInitLowering
- : public ConvertOpToLLVMPattern<nvgpu::MBarrierInitOp> {
- using ConvertOpToLLVMPattern<nvgpu::MBarrierInitOp>::ConvertOpToLLVMPattern;
+ : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
+ using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
LogicalResult
matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
rewriter.setInsertionPoint(op);
- Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
- op.getBarrier(), adaptor.getBarrier());
-
+ Value barrier = getMbarrierPtr(op, mbarrierType, adaptor.getBarriers(),
+ adaptor.getMbarId(), rewriter);
Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
-
- if (isMbarrierShared(op.getBarrier().getType())) {
+ if (isMbarrierShared(mbarrierType)) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
count);
} else {
@@ -803,16 +809,17 @@ struct NVGPUMBarrierInitLowering
/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
struct NVGPUMBarrierArriveLowering
- : public ConvertOpToLLVMPattern<nvgpu::MBarrierArriveOp> {
- using ConvertOpToLLVMPattern<nvgpu::MBarrierArriveOp>::ConvertOpToLLVMPattern;
+ : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
+ using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
- op.getBarrier(), adaptor.getBarrier());
+ Value barrier =
+ getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
- if (isMbarrierShared(op.getBarrier().getType())) {
+ if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
barrier);
} else {
@@ -826,19 +833,19 @@ struct NVGPUMBarrierArriveLowering
/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`
struct NVGPUMBarrierArriveNoCompleteLowering
- : public ConvertOpToLLVMPattern<nvgpu::MBarrierArriveNoCompleteOp> {
- using ConvertOpToLLVMPattern<
- nvgpu::MBarrierArriveNoCompleteOp>::ConvertOpToLLVMPattern;
-
+ : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
+ using MBarrierBasePattern<
+ nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
- op.getBarrier(), adaptor.getBarrier());
+ Value barrier =
+ getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
- if (isMbarrierShared(op.getBarrier().getType())) {
+ if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
op, tokenType, barrier, count);
} else {
@@ -851,17 +858,16 @@ struct NVGPUMBarrierArriveNoCompleteLowering
/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
struct NVGPUMBarrierTestWaitLowering
- : public ConvertOpToLLVMPattern<nvgpu::MBarrierTestWaitOp> {
- using ConvertOpToLLVMPattern<
- nvgpu::MBarrierTestWaitOp>::ConvertOpToLLVMPattern;
-
+ : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
+ using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
LogicalResult
matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
- op.getBarrier(), adaptor.getBarrier());
+ Value barrier =
+ getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ adaptor.getMbarId(), rewriter);
Type retType = rewriter.getI1Type();
- if (isMbarrierShared(op.getBarrier().getType())) {
+ if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
op, retType, barrier, adaptor.getToken());
} else {
@@ -873,18 +879,18 @@ struct NVGPUMBarrierTestWaitLowering
};
struct NVGPUMBarrierArriveExpectTxLowering
- : public ConvertOpToLLVMPattern<nvgpu::MBarrierArriveExpectTxOp> {
- using ConvertOpToLLVMPattern<
- nvgpu::MBarrierArriveExpectTxOp>::ConvertOpToLLVMPattern;
-
+ : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
+ using MBarrierBasePattern<
+ nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
- op.getBarrier(), adaptor.getBarrier());
+ Value barrier =
+ getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+ adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount());
- if (isMbarrierShared(op.getBarrier().getType())) {
+ if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
op, barrier, txcount);
return success();
@@ -897,19 +903,19 @@ struct NVGPUMBarrierArriveExpectTxLowering
};
struct NVGPUMBarrierTryWaitParityLowering
- : public ConvertOpToLLVMPattern<nvgpu::MBarrierTryWaitParityOp> {
- using ConvertOpToLLVMPattern<
- nvgpu::MBarrierTryWaitP...
<truncated>
</pre>
</details>
https://github.com/llvm/llvm-project/pull/65951
More information about the Mlir-commits
mailing list