[flang-commits] [flang] [mlir] [mlir][NVVM] Add ops for vote all and any sync (PR #134309)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Apr 3 16:01:12 PDT 2025
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/134309
>From b8caf553380327c273b8b9b5d17c40f7256cd3c5 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 3 Apr 2025 14:26:11 -0700
Subject: [PATCH 1/2] [mlir][NVVM] Add ops for vote all and any sync
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 20 +++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 22 +++++++++++++++++++--
mlir/test/Dialect/LLVMIR/nvvm.mlir | 4 ++++
3 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..4a549d02dc281 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -819,6 +819,26 @@ def NVVM_VoteBallotOp :
let hasCustomAssemblyFormat = 1;
}
+def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+ string llvmBuilder = [{
+ $res = createIntrinsicCall(builder,
+ llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
+def NVVM_VoteAnySyncOp : NVVM_Op<"vote.any.sync">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+ string llvmBuilder = [{
+ $res = createIntrinsicCall(builder,
+ llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
def NVVM_SyncWarpOp :
NVVM_Op<"bar.warp.sync">,
Arguments<(ins LLVM_Type:$mask)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..8ef74fcef90e8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -58,8 +58,7 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << " : " << op->getResultTypes();
}
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseVoteOps(OpAsmParser &parser, OperationState &result) {
MLIRContext *context = parser.getContext();
auto int32Ty = IntegerType::get(context, 32);
auto int1Ty = IntegerType::get(context, 1);
@@ -74,8 +73,27 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(), result.operands));
}
+// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
+ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseVoteOps(parser, result);
+}
+
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+// <operation> ::= `llvm.nvvm.vote.all.sync %mask, %pred` : result_type
+ParseResult VoteAllSyncOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseVoteOps(parser, result);
+}
+
+void VoteAllSyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+
+// <operation> ::= `llvm.nvvm.vote.any.sync %mask, %pred` : result_type
+ParseResult VoteAnySyncOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseVoteOps(parser, result);
+}
+
+void VoteAnySyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 18bf39424f0bf..9eec62ff67561 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -131,6 +131,10 @@ func.func @nvvm_shfl_pred(
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
+ // CHECK: nvvm.vote.all.sync %{{.*}}, %{{.*}} : i32
+ %1 = nvvm.vote.all.sync %arg0, %arg1 : i32
+ // CHECK: nvvm.vote.any.sync %{{.*}}, %{{.*}} : i32
+ %2 = nvvm.vote.any.sync %arg0, %arg1 : i32
llvm.return %0 : i32
}
>From 1ade5b5f532a9b0e80c84ef2ec4a0ff6dcf7495a Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 3 Apr 2025 15:43:49 -0700
Subject: [PATCH 2/2] Merge ops
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 3 +-
flang/test/Lower/CUDA/cuda-device-proc.cuf | 2 +-
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 43 ++++++--------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 59 ++++---------------
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 15 +++++
mlir/test/Dialect/LLVMIR/nvvm.mlir | 14 +++--
mlir/test/Target/LLVMIR/nvvmir.mlir | 8 ++-
7 files changed, 65 insertions(+), 79 deletions(-)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 2df9349269a69..fd4a6c5897364 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6542,7 +6542,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
mlir::Value arg1 =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
return builder
- .create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
+ .create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
+ mlir::NVVM::VoteSyncKind::ballot)
.getResult();
}
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index a7f9038761b51..7d6d920dfb2e8 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -303,7 +303,7 @@ end subroutine
! CHECK-LABEL: func.func @_QPtestvote()
! CHECK: fir.call @llvm.nvvm.vote.all.sync
! CHECK: fir.call @llvm.nvvm.vote.any.sync
-! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
+! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4a549d02dc281..f289eb9c1df0f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -808,35 +808,41 @@ def NVVM_ShflOp :
let hasVerifier = 1;
}
-def NVVM_VoteBallotOp :
- NVVM_Op<"vote.ballot.sync">,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
- string llvmBuilder = [{
- $res = createIntrinsicCall(builder,
- llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
- }];
- let hasCustomAssemblyFormat = 1;
+def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
+def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
+def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
+def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;
+
+def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
+ [VoteSyncKindAny, VoteSyncKindAll,
+ VoteSyncKindBallot, VoteSyncKindUni]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
}
-def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
- string llvmBuilder = [{
- $res = createIntrinsicCall(builder,
- llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
- }];
- let hasCustomAssemblyFormat = 1;
-}
+def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
-def NVVM_VoteAnySyncOp : NVVM_Op<"vote.any.sync">,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+def NVVM_VoteSyncOp
+ : NVVM_Op<"vote.sync">,
+ Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
+ Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
+ let summary = "Vote across the threads selected by a mask";
+ let description = [{
+ The `vote.sync` op votes on the i1 predicate `$pred` across the threads
+ selected by the i32 member mask `$mask`. `$kind` picks the vote and the
+ `llvm.nvvm.vote.<kind>.sync` intrinsic the op lowers to.
+
+ - `any`, `all` and `uni` produce an `i1` result.
+ - `ballot` produces an `i32` result.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync)
+ }];
string llvmBuilder = [{
- $res = createIntrinsicCall(builder,
- llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+ auto intId = getVoteSyncIntrinsicId($kind);
+ $res = createIntrinsicCall(builder, intId, {$mask, $pred});
}];
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
+ let hasVerifier = 1;
}
def NVVM_SyncWarpOp :
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8ef74fcef90e8..6ba4c372365db 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -48,52 +48,6 @@ using namespace NVVM;
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
-//===----------------------------------------------------------------------===//
-// Printing/parsing for NVVM ops
-//===----------------------------------------------------------------------===//
-
-static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
- p << " " << op->getOperands();
- if (op->getNumResults() > 0)
- p << " : " << op->getResultTypes();
-}
-
-static ParseResult parseVoteOps(OpAsmParser &parser, OperationState &result) {
- MLIRContext *context = parser.getContext();
- auto int32Ty = IntegerType::get(context, 32);
- auto int1Ty = IntegerType::get(context, 1);
-
- SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
- Type type;
- return failure(parser.parseOperandList(ops) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.addTypeToList(type, result.types) ||
- parser.resolveOperands(ops, {int32Ty, int1Ty},
- parser.getNameLoc(), result.operands));
-}
-
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseVoteOps(parser, result);
-}
-
-void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
-// <operation> ::= `llvm.nvvm.vote.all.sync %mask, %pred` : result_type
-ParseResult VoteAllSyncOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseVoteOps(parser, result);
-}
-
-void VoteAllSyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
-// <operation> ::= `llvm.nvvm.vote.any.sync %mask, %pred` : result_type
-ParseResult VoteAnySyncOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseVoteOps(parser, result);
-}
-
-void VoteAnySyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -1178,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
return success();
}
+LogicalResult NVVM::VoteSyncOp::verify() {
+  // The result type is fixed by the kind: i32 for 'ballot', i1 otherwise.
+  if (getKind() == NVVM::VoteSyncKind::ballot) {
+    if (!getType().isInteger(32))
+      return emitOpError("vote.sync 'ballot' returns an i32");
+    return success();
+  }
+  // Every other kind ('any', 'all', 'uni') must produce an i1.
+  if (!getType().isInteger(1))
+    return emitOpError("vote.sync 'any', 'all' and 'uni' return an i1");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9d14ff09ab434..beff90237562d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
}
}
+/// Return the `llvm.nvvm.vote.<kind>.sync` intrinsic ID for the given kind.
+static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
+  switch (kind) {
+  case NVVM::VoteSyncKind::any:
+    return llvm::Intrinsic::nvvm_vote_any_sync;
+  case NVVM::VoteSyncKind::all:
+    return llvm::Intrinsic::nvvm_vote_all_sync;
+  case NVVM::VoteSyncKind::ballot:
+    return llvm::Intrinsic::nvvm_vote_ballot_sync;
+  case NVVM::VoteSyncKind::uni:
+    return llvm::Intrinsic::nvvm_vote_uni_sync;
+  }
+  llvm_unreachable("unsupported vote kind");
+}
+
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
int32_t num) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 9eec62ff67561..d3915492c38a0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -129,12 +129,14 @@ func.func @nvvm_shfl_pred(
// CHECK-LABEL: @nvvm_vote(
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
- // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
- %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
- // CHECK: nvvm.vote.all.sync %{{.*}}, %{{.*}} : i32
- %1 = nvvm.vote.all.sync %arg0, %arg1 : i32
- // CHECK: nvvm.vote.any.sync %{{.*}}, %{{.*}} : i32
- %2 = nvvm.vote.any.sync %arg0, %arg1 : i32
+ // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+ %0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
+ // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+ %1 = nvvm.vote.sync all %arg0, %arg1 -> i1
+ // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+ %2 = nvvm.vote.sync any %arg0, %arg1 -> i1
+ // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+ %3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
llvm.return %0 : i32
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c3ec88db1d694..3a0713f2feee8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred(
// CHECK-LABEL: @nvvm_vote
llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
- %3 = nvvm.vote.ballot.sync %0, %1 : i32
+ %3 = nvvm.vote.sync ballot %0, %1 -> i32
+ // CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
+ %4 = nvvm.vote.sync all %0, %1 -> i1
+ // CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
+ %5 = nvvm.vote.sync any %0, %1 -> i1
+ // CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
+ %6 = nvvm.vote.sync uni %0, %1 -> i1
llvm.return %3 : i32
}
More information about the flang-commits
mailing list