[flang-commits] [flang] cd2f85a - [mlir][NVVM] Add ops for vote all and any sync (#134309)
via flang-commits
flang-commits at lists.llvm.org
Fri Apr 4 11:06:13 PDT 2025
Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-04-04T11:06:10-07:00
New Revision: cd2f85a24b55336c96de56276b54d1196fd55fd1
URL: https://github.com/llvm/llvm-project/commit/cd2f85a24b55336c96de56276b54d1196fd55fd1
DIFF: https://github.com/llvm/llvm-project/commit/cd2f85a24b55336c96de56276b54d1196fd55fd1.diff
LOG: [mlir][NVVM] Add ops for vote all and any sync (#134309)
Add operations for `nvvm.vote.all.sync` and `nvvm.vote.any.sync`
intrinsics similar to `nvvm.vote.ballot.sync`.
Added:
Modified:
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/test/Lower/CUDA/cuda-device-proc.cuf
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 0ca636bc091ec..702a55a49c953 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6616,7 +6616,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
mlir::Value arg1 =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
return builder
- .create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
+ .create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
+ mlir::NVVM::VoteSyncKind::ballot)
.getResult();
}
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index a7f9038761b51..7d6d920dfb2e8 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -303,7 +303,7 @@ end subroutine
! CHECK-LABEL: func.func @_QPtestvote()
! CHECK: fir.call @llvm.nvvm.vote.all.sync
! CHECK: fir.call @llvm.nvvm.vote.any.sync
-! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
+! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..0a6e66919f021 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -808,15 +808,49 @@ def NVVM_ShflOp :
let hasVerifier = 1;
}
-def NVVM_VoteBallotOp :
- NVVM_Op<"vote.ballot.sync">,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
+def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
+def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
+def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;
+
+def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
+ [VoteSyncKindAny, VoteSyncKindAll,
+ VoteSyncKindBallot, VoteSyncKindUni]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
+
+def NVVM_VoteSyncOp
+ : NVVM_Op<"vote.sync">,
+ Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
+ Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
+ let summary = "Vote across thread group";
+ let description = [{
+ The `vote.sync` op causes the executing thread to wait until all non-exited
+ threads corresponding to membermask have executed `vote.sync` with the same
+ qualifiers and the same membermask value before resuming execution.
+
+ The vote operation kinds are:
+ - `any`: True if source predicate is True for some thread in membermask.
+ - `all`: True if source predicate is True for all non-exited threads in
+ membermask.
+ - `uni`: True if source predicate has the same value in all non-exited
+ threads in membermask.
+ - `ballot`: In the ballot form, the destination result is a 32-bit integer.
+ In this form, the predicate from each thread in membermask is copied into
+ the corresponding bit position of the result, where the bit position
+ corresponds to the thread’s lane id.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync)
+ }];
string llvmBuilder = [{
- $res = createIntrinsicCall(builder,
- llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
+ auto intId = getVoteSyncIntrinsicId($kind);
+ $res = createIntrinsicCall(builder, intId, {$mask, $pred});
}];
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
+ let hasVerifier = 1;
}
def NVVM_SyncWarpOp :
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..09bff6101edd3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -48,34 +48,6 @@ using namespace NVVM;
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
-//===----------------------------------------------------------------------===//
-// Printing/parsing for NVVM ops
-//===----------------------------------------------------------------------===//
-
-static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
- p << " " << op->getOperands();
- if (op->getNumResults() > 0)
- p << " : " << op->getResultTypes();
-}
-
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
- MLIRContext *context = parser.getContext();
- auto int32Ty = IntegerType::get(context, 32);
- auto int1Ty = IntegerType::get(context, 1);
-
- SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
- Type type;
- return failure(parser.parseOperandList(ops) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.addTypeToList(type, result.types) ||
- parser.resolveOperands(ops, {int32Ty, int1Ty},
- parser.getNameLoc(), result.operands));
-}
-
-void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -1160,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
return success();
}
+LogicalResult NVVM::VoteSyncOp::verify() {
+ if (getKind() == NVVM::VoteSyncKind::ballot) {
+ if (!getType().isInteger(32)) {
+ return emitOpError("vote.sync 'ballot' returns an i32");
+ }
+ } else {
+ if (!getType().isInteger(1)) {
+ return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9d14ff09ab434..beff90237562d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
}
}
+static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
+ switch (kind) {
+ case NVVM::VoteSyncKind::any:
+ return llvm::Intrinsic::nvvm_vote_any_sync;
+ case NVVM::VoteSyncKind::all:
+ return llvm::Intrinsic::nvvm_vote_all_sync;
+ case NVVM::VoteSyncKind::ballot:
+ return llvm::Intrinsic::nvvm_vote_ballot_sync;
+ case NVVM::VoteSyncKind::uni:
+ return llvm::Intrinsic::nvvm_vote_uni_sync;
+ default:
+ llvm_unreachable("unsupported vote kind");
+ }
+}
+
/// Return the intrinsic ID associated with ldmatrix for the given parameters.
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
int32_t num) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 18bf39424f0bf..d3915492c38a0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -129,8 +129,14 @@ func.func @nvvm_shfl_pred(
// CHECK-LABEL: @nvvm_vote(
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
- // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
- %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
+ // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+ %0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
+ // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+ %1 = nvvm.vote.sync all %arg0, %arg1 -> i1
+ // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+ %2 = nvvm.vote.sync any %arg0, %arg1 -> i1
+ // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+ %3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
llvm.return %0 : i32
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c3ec88db1d694..3a0713f2feee8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred(
// CHECK-LABEL: @nvvm_vote
llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
- %3 = nvvm.vote.ballot.sync %0, %1 : i32
+ %3 = nvvm.vote.sync ballot %0, %1 -> i32
+ // CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
+ %4 = nvvm.vote.sync all %0, %1 -> i1
+ // CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
+ %5 = nvvm.vote.sync any %0, %1 -> i1
+ // CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
+ %6 = nvvm.vote.sync uni %0, %1 -> i1
llvm.return %3 : i32
}
More information about the flang-commits
mailing list