[Mlir-commits] [mlir] [MLIR][NVVM] Adds an explicit aligned boolean attribute to nvvm.barrier (PR #192203)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Apr 18 19:13:42 PDT 2026
https://github.com/xys-syx updated https://github.com/llvm/llvm-project/pull/192203
>From 4fe2004cd01c3a7cd0d380c7f72f092d5b5de7fb Mon Sep 17 00:00:00 2001
From: Yuansui Xu <xuyuansui at outlook.com>
Date: Wed, 15 Apr 2026 02:24:53 -0500
Subject: [PATCH 1/3] Add an `aligned` attribute to the nvvm.barrier op
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 18 ++--
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 87 +++++++++++++++----
.../Dialect/LLVMIR/nvvm-canonicalize.mlir | 34 ++++++++
mlir/test/Dialect/LLVMIR/nvvm.mlir | 7 ++
mlir/test/Target/LLVMIR/nvvm/barrier.mlir | 21 +++++
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 7 --
6 files changed, 146 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9b2a8985a1a44..8a8a8722d7d76 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1138,13 +1138,16 @@ def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
within a CTA (Cooperative Thread Array). It causes executing threads to wait for
all non-exited threads participating in the barrier to arrive.
- The operation takes two optional operands:
+ The operation takes the following optional operands and attributes:
- `barrierId`: Specifies a logical barrier resource with value 0 through 15.
Each CTA instance has sixteen barriers numbered 0..15. Defaults to 0 if not specified.
- `numberOfThreads`: Specifies the number of threads participating in the barrier.
When specified, the value must be a multiple of the warp size. If not specified,
all threads in the CTA participate in the barrier.
+ - `aligned`: When `true` (the default), lowers to the `.aligned` form of the
+ underlying `@llvm.nvvm.barrier.cta.*` intrinsic family. When `false` (spelled
+ `non_aligned` in the assembly format), lowers to the non-`.aligned` form.
- `reductionOp`: specifies the reduction operation (`popc`, `and`, `or`).
- `reductionPredicate`: specifies the predicate to be used with the
`reductionOp`.
@@ -1157,7 +1160,7 @@ def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
When a barrier completes, the waiting threads are restarted without delay, and
the barrier is reinitialized so that it can be immediately reused.
- This operation generates an aligned barrier, indicating that all threads in the CTA
+ By default, this operation generates an aligned barrier, indicating that all threads in the CTA
will execute the same barrier instruction. Behavior is undefined if all threads in the
CTA do not reach this instruction.
@@ -1166,7 +1169,8 @@ def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
OptionalAttr<BarrierReductionAttr>:$reductionOp,
- Optional<I32>:$reductionPredicate);
+ Optional<I32>:$reductionPredicate,
+ DefaultValuedAttr<BoolAttr, "true">:$aligned);
string llvmBuilder = [{
auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
@@ -1179,16 +1183,20 @@ def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
let results = (outs Optional<I32>:$res);
let hasVerifier = 1;
+ let hasCanonicalizeMethod = 1;
let assemblyFormat =
"(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
+ "custom<Aligned>($aligned) "
"(qualified($reductionOp)^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
let builders = [OpBuilder<(ins), [{
- return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});
+ return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{},
+ $_builder.getBoolAttr(true));
}]>,
OpBuilder<(ins "Value":$barrierId), [{
- return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{});
+ return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{},
+ $_builder.getBoolAttr(true));
}]>];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 31e7ff209db5c..f2fdbe5bf0822 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
@@ -2916,9 +2917,8 @@ LogicalResult NVVM::SetMaxRegisterOp::verify() {
}
LogicalResult NVVM::BarrierOp::verify() {
- if (getNumberOfThreads() && !getBarrierId())
- return emitOpError(
- "barrier id is missing, it should be set between 0 to 15");
+ if (getReductionOp() && getNumberOfThreads())
+ return emitOpError("reduction cannot be combined with number_of_threads");
if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
return emitOpError("reduction are only available when id is 0");
@@ -2943,6 +2943,23 @@ bool BarrierOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return isCompatibleReturnTypesOptionalResult(l, r);
}
+/// Folds `nvvm.barrier id = %c0` (constant zero) into the form with
+/// `barrierId` omitted; an absent id also lowers to barrier id 0.
+LogicalResult NVVM::BarrierOp::canonicalize(NVVM::BarrierOp op,
+ PatternRewriter &rewriter) {
+ Value id = op.getBarrierId();
+ if (!id)
+ return failure();
+
+ APInt value;
+ if (!matchPattern(id, m_ConstantInt(&value)) || !value.isZero())
+ return failure();
+
+ rewriter.modifyOpInPlace(op,
+ [&]() { op.getBarrierIdMutable().clear(); });
+ return success();
+}
+
LogicalResult NVVM::Tcgen05CpOp::verify() {
auto mc = getMulticast();
@@ -3432,6 +3449,37 @@ void SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
+/// Returns the LLVM intrinsic ID for the `nvvm.barrier.cta.sync[.aligned]
+/// .{all,count}` variant matching `aligned` and whether a thread count is
+/// supplied.
+static llvm::Intrinsic::ID getBarrierSyncIntrinsic(bool aligned, bool hasCount) {
+ if (hasCount)
+ return aligned ? llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count
+ : llvm::Intrinsic::nvvm_barrier_cta_sync_count;
+ return aligned ? llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all
+ : llvm::Intrinsic::nvvm_barrier_cta_sync_all;
+}
+
+/// Returns the LLVM intrinsic ID for the `nvvm.barrier.cta.red.{and,or,popc}
+/// [.aligned].all` variant matching `aligned` and `red`. Only the `.all` shape
+/// is modeled: `BarrierOp::verify` rejects `reduction + numberOfThreads`, so
+/// the `.red.*.count` intrinsics are intentionally unreachable here.
+static llvm::Intrinsic::ID
+getBarrierReductionIntrinsic(bool aligned, NVVM::BarrierReduction red) {
+ switch (red) {
+ case NVVM::BarrierReduction::AND:
+ return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all
+ : llvm::Intrinsic::nvvm_barrier_cta_red_and_all;
+ case NVVM::BarrierReduction::OR:
+ return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all
+ : llvm::Intrinsic::nvvm_barrier_cta_red_or_all;
+ case NVVM::BarrierReduction::POPC:
+ return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all
+ : llvm::Intrinsic::nvvm_barrier_cta_red_popc_all;
+ }
+ llvm_unreachable("unknown BarrierReduction kind");
+}
+
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::BarrierOp>(op);
@@ -3441,24 +3489,15 @@ mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
llvm::Intrinsic::ID id;
llvm::SmallVector<llvm::Value *> args = {barrierId};
if (thisOp.getNumberOfThreads()) {
- id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
+ id = getBarrierSyncIntrinsic(thisOp.getAligned(), /*hasCount=*/true);
args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
} else if (thisOp.getReductionOp()) {
- switch (*thisOp.getReductionOp()) {
- case NVVM::BarrierReduction::AND:
- id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
- break;
- case NVVM::BarrierReduction::OR:
- id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
- break;
- case NVVM::BarrierReduction::POPC:
- id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
- break;
- }
+ id = getBarrierReductionIntrinsic(thisOp.getAligned(),
+ *thisOp.getReductionOp());
args.push_back(builder.CreateICmpNE(
mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
} else {
- id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
+ id = getBarrierSyncIntrinsic(thisOp.getAligned(), /*hasCount=*/false);
}
return {id, std::move(args)};
@@ -6228,6 +6267,22 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
return success();
}
+/// Parses the optional `non_aligned` keyword; absent means `aligned = true`.
+static ParseResult parseAligned(OpAsmParser &parser, BoolAttr &aligned) {
+ bool isNonAligned =
+ succeeded(parser.parseOptionalKeyword("non_aligned"));
+ aligned = parser.getBuilder().getBoolAttr(!isNonAligned);
+ return success();
+}
+
+/// Prints the `non_aligned` keyword marker for `nvvm.barrier` when the op is
+/// non-aligned. Nothing is printed for the default aligned case.
+static void printAligned(OpAsmPrinter &printer, Operation *op,
+ BoolAttr aligned) {
+ if (!aligned.getValue())
+ printer << "non_aligned";
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-canonicalize.mlir b/mlir/test/Dialect/LLVMIR/nvvm-canonicalize.mlir
index fe9afd840bab2..e6a2d05b89ef9 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-canonicalize.mlir
@@ -7,3 +7,37 @@ llvm.func @subf_canonicalize(%arg0 : f32, %arg1 : f32) -> f32 {
%0 = nvvm.subf %arg0, %arg1 : f32
llvm.return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: @nvvm_barrier_fold_id_zero
+llvm.func @nvvm_barrier_fold_id_zero() {
+ // CHECK-NOT: llvm.mlir.constant
+ // CHECK: nvvm.barrier
+ // CHECK-NOT: id =
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ nvvm.barrier id = %c0
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_barrier_keep_id_nonzero
+llvm.func @nvvm_barrier_keep_id_nonzero() {
+ // CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : i32) : i32
+ // CHECK: nvvm.barrier id = %[[C5]]
+ %c5 = llvm.mlir.constant(5 : i32) : i32
+ nvvm.barrier id = %c5
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_barrier_fold_id_zero_with_count
+llvm.func @nvvm_barrier_fold_id_zero_with_count(%n : i32) {
+ // CHECK: nvvm.barrier number_of_threads = %{{.*}}
+ // CHECK-NOT: id =
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ nvvm.barrier id = %c0 number_of_threads = %n
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c039edc6b5de5..c19be68492db3 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -48,10 +48,17 @@ func.func @llvm_nvvm_barrier0() {
llvm.func @llvm_nvvm_barrier(%barId : i32, %numberOfThreads : i32) {
// CHECK: nvvm.barrier
nvvm.barrier
+ // CHECK-NOT: aligned
// CHECK: nvvm.barrier id = %[[barId]]
nvvm.barrier id = %barId
// CHECK: nvvm.barrier id = %[[barId]] number_of_threads = %[[numberOfThreads]]
nvvm.barrier id = %barId number_of_threads = %numberOfThreads
+ // CHECK: nvvm.barrier number_of_threads = %[[numberOfThreads]]
+ nvvm.barrier number_of_threads = %numberOfThreads
+ // CHECK: nvvm.barrier non_aligned
+ nvvm.barrier non_aligned
+ // CHECK: nvvm.barrier id = %[[barId]] non_aligned
+ nvvm.barrier id = %barId non_aligned
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
index 7e654eb8dc572..d54891000f2e1 100644
--- a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
@@ -25,6 +25,27 @@ llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand :
// LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier.cta.red.popc.aligned.all(i32 0, i1 %[[redOperandCmp3]])
// CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<popc> %{{.*}} -> i32
%2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32
+ // LLVM: call void @llvm.nvvm.barrier.cta.sync.all(i32 0)
+ // CHECK: nvvm.barrier non_aligned
+ nvvm.barrier non_aligned
+ // LLVM: call void @llvm.nvvm.barrier.cta.sync.all(i32 %[[barId]])
+ // CHECK: nvvm.barrier id = %{{.*}} non_aligned
+ nvvm.barrier id = %barID non_aligned
+ // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 0, i32 %[[numThreads]])
+ // CHECK: nvvm.barrier number_of_threads = %{{.*}}
+ nvvm.barrier number_of_threads = %numberOfThreads
+ // LLVM: call void @llvm.nvvm.barrier.cta.sync.count(i32 %[[barId]], i32 %[[numThreads]])
+ // CHECK: nvvm.barrier id = %{{.*}} number_of_threads = %{{.*}} non_aligned
+ nvvm.barrier id = %barID number_of_threads = %numberOfThreads non_aligned
+ // LLVM: %{{.*}} = call i1 @llvm.nvvm.barrier.cta.red.and.all(i32 0, i1 %{{.*}})
+ // CHECK: %{{.*}} = nvvm.barrier non_aligned #nvvm.reduction<and> %{{.*}} -> i32
+ %3 = nvvm.barrier non_aligned #nvvm.reduction<and> %redOperand -> i32
+ // LLVM: %{{.*}} = call i1 @llvm.nvvm.barrier.cta.red.or.all(i32 0, i1 %{{.*}})
+ // CHECK: %{{.*}} = nvvm.barrier non_aligned #nvvm.reduction<or> %{{.*}} -> i32
+ %4 = nvvm.barrier non_aligned #nvvm.reduction<or> %redOperand -> i32
+ // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier.cta.red.popc.all(i32 0, i1 %{{.*}})
+ // CHECK: %{{.*}} = nvvm.barrier non_aligned #nvvm.reduction<popc> %{{.*}} -> i32
+ %5 = nvvm.barrier non_aligned #nvvm.reduction<popc> %redOperand -> i32
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 2726fc7a40ef0..0690dd4700e2d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -1,12 +1,13 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
-llvm.func @kernel_func(%numberOfThreads : i32) {
- // expected-error @below {{'nvvm.barrier' op barrier id is missing, it should be set between 0 to 15}}
- nvvm.barrier number_of_threads = %numberOfThreads
-}
+llvm.func @barrier_reduction_with_count(%numberOfThreads : i32, %pred : i32) {
+ // expected-error @below {{'nvvm.barrier' op reduction cannot be combined with number_of_threads}}
+ %0 = nvvm.barrier number_of_threads = %numberOfThreads #nvvm.reduction<and> %pred -> i32
+ llvm.return
+}
// -----
// expected-error @below {{'"nvvm.minctasm"' attribute must be integer constant}}
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = "foo"} {
llvm.return
>From 74b7204859fe72e4bf7496b1fc36e32b3c53b280 Mon Sep 17 00:00:00 2001
From: Yuansui Xu <xuyuansui at outlook.com>
Date: Wed, 15 Apr 2026 03:00:28 -0500
Subject: [PATCH 2/3] Apply clang-format
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f2fdbe5bf0822..40e1cf1671d4d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2955,8 +2955,7 @@ LogicalResult NVVM::BarrierOp::canonicalize(NVVM::BarrierOp op,
if (!matchPattern(id, m_ConstantInt(&value)) || !value.isZero())
return failure();
- rewriter.modifyOpInPlace(op,
- [&]() { op.getBarrierIdMutable().clear(); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getBarrierIdMutable().clear(); });
return success();
}
@@ -3452,7 +3451,8 @@ void SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
/// Returns the LLVM intrinsic ID for the `nvvm.barrier.cta.sync[.aligned]
/// .{all,count}` variant matching `aligned` and whether a thread count is
/// supplied.
-static llvm::Intrinsic::ID getBarrierSyncIntrinsic(bool aligned, bool hasCount) {
+static llvm::Intrinsic::ID getBarrierSyncIntrinsic(bool aligned,
+ bool hasCount) {
if (hasCount)
return aligned ? llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count
: llvm::Intrinsic::nvvm_barrier_cta_sync_count;
@@ -6269,8 +6269,7 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
/// Parses the `non_aligned` keyword marker for `nvvm.barrier`.
static ParseResult parseAligned(OpAsmParser &parser, BoolAttr &aligned) {
- bool isNonAligned =
- succeeded(parser.parseOptionalKeyword("non_aligned"));
+ bool isNonAligned = succeeded(parser.parseOptionalKeyword("non_aligned"));
aligned = parser.getBuilder().getBoolAttr(!isNonAligned);
return success();
}
>From cfe066d279bc0f3d91f490efb1625cf3757e6faa Mon Sep 17 00:00:00 2001
From: Yuansui Xu <xuyuansui at outlook.com>
Date: Sat, 18 Apr 2026 21:13:28 -0500
Subject: [PATCH 3/3] Drop non-aligned reduction lowering and reject it in the verifier
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 43 ++++++++++------------
mlir/test/Target/LLVMIR/nvvm/barrier.mlir | 9 -----
2 files changed, 19 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 40e1cf1671d4d..6acd197010988 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2917,8 +2917,20 @@ LogicalResult NVVM::SetMaxRegisterOp::verify() {
}
LogicalResult NVVM::BarrierOp::verify() {
+ // Keep rejecting reduction + number_of_threads: the lowering selects the
+ // `.count` sync intrinsic first and only models `.red.*.all`, so the
+ // reduction and its i32 result would be silently dropped. The id check below
+ // no longer implies this now that number_of_threads does not require an id.
if (getReductionOp() && getNumberOfThreads())
return emitOpError("reduction cannot be combined with number_of_threads");
+ // Temporary: the reduction lowering path still hardcodes the `.aligned`
+ // intrinsic spelling. A follow-up PR will extend it along the aligned axis
+ // (and add the `.red.*.count` variants); until then, reject non-aligned
+ // reductions here so users cannot construct an op whose `aligned = false`
+ // would be silently dropped during lowering.
+ if (getReductionOp() && !getAligned())
+ return emitOpError(
+ "non-aligned reduction is not supported yet; tracked as a follow-up");
if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
return emitOpError("reduction are only available when id is 0");
@@ -3460,26 +3466,6 @@ static llvm::Intrinsic::ID getBarrierSyncIntrinsic(bool aligned,
: llvm::Intrinsic::nvvm_barrier_cta_sync_all;
}
-/// Returns the LLVM intrinsic ID for the `nvvm.barrier.cta.red.{and,or,popc}
-/// [.aligned].all` variant matching `aligned` and `red`. Only the `.all` shape
-/// is modeled: `BarrierOp::verify` rejects `reduction + numberOfThreads`, so
-/// the `.red.*.count` intrinsics are intentionally unreachable here.
-static llvm::Intrinsic::ID
-getBarrierReductionIntrinsic(bool aligned, NVVM::BarrierReduction red) {
- switch (red) {
- case NVVM::BarrierReduction::AND:
- return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all
- : llvm::Intrinsic::nvvm_barrier_cta_red_and_all;
- case NVVM::BarrierReduction::OR:
- return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all
- : llvm::Intrinsic::nvvm_barrier_cta_red_or_all;
- case NVVM::BarrierReduction::POPC:
- return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all
- : llvm::Intrinsic::nvvm_barrier_cta_red_popc_all;
- }
- llvm_unreachable("unknown BarrierReduction kind");
-}
-
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::BarrierOp>(op);
@@ -3492,8 +3478,17 @@ mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
id = getBarrierSyncIntrinsic(thisOp.getAligned(), /*hasCount=*/true);
args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
} else if (thisOp.getReductionOp()) {
- id = getBarrierReductionIntrinsic(thisOp.getAligned(),
- *thisOp.getReductionOp());
+ switch (*thisOp.getReductionOp()) {
+ case NVVM::BarrierReduction::AND:
+ id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
+ break;
+ case NVVM::BarrierReduction::OR:
+ id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
+ break;
+ case NVVM::BarrierReduction::POPC:
+ id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
+ break;
+ }
args.push_back(builder.CreateICmpNE(
mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
} else {
diff --git a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
index d54891000f2e1..1d934e61fd4e2 100644
--- a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
@@ -37,15 +37,6 @@ llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand :
// LLVM: call void @llvm.nvvm.barrier.cta.sync.count(i32 %[[barId]], i32 %[[numThreads]])
// CHECK: nvvm.barrier id = %{{.*}} number_of_threads = %{{.*}} non_aligned
nvvm.barrier id = %barID number_of_threads = %numberOfThreads non_aligned
- // LLVM: %{{.*}} = call i1 @llvm.nvvm.barrier.cta.red.and.all(i32 0, i1 %{{.*}})
- // CHECK: %{{.*}} = nvvm.barrier non_aligned #nvvm.reduction<and> %{{.*}} -> i32
- %3 = nvvm.barrier non_aligned #nvvm.reduction<and> %redOperand -> i32
- // LLVM: %{{.*}} = call i1 @llvm.nvvm.barrier.cta.red.or.all(i32 0, i1 %{{.*}})
- // CHECK: %{{.*}} = nvvm.barrier non_aligned #nvvm.reduction<or> %{{.*}} -> i32
- %4 = nvvm.barrier non_aligned #nvvm.reduction<or> %redOperand -> i32
- // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier.cta.red.popc.all(i32 0, i1 %{{.*}})
- // CHECK: %{{.*}} = nvvm.barrier non_aligned #nvvm.reduction<popc> %{{.*}} -> i32
- %5 = nvvm.barrier non_aligned #nvvm.reduction<popc> %redOperand -> i32
llvm.return
}
More information about the Mlir-commits
mailing list