[flang-commits] [flang] cd2f85a - [mlir][NVVM] Add ops for vote all and any sync (#134309)

via flang-commits flang-commits at lists.llvm.org
Fri Apr 4 11:06:13 PDT 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-04-04T11:06:10-07:00
New Revision: cd2f85a24b55336c96de56276b54d1196fd55fd1

URL: https://github.com/llvm/llvm-project/commit/cd2f85a24b55336c96de56276b54d1196fd55fd1
DIFF: https://github.com/llvm/llvm-project/commit/cd2f85a24b55336c96de56276b54d1196fd55fd1.diff

LOG: [mlir][NVVM] Add ops for vote all and any sync (#134309)

Add operations for `nvvm.vote.all.sync` and `nvvm.vote.any.sync`
intrinsics similar to `nvvm.vote.ballot.sync`.

Added: 
    

Modified: 
    flang/lib/Optimizer/Builder/IntrinsicCall.cpp
    flang/test/Lower/CUDA/cuda-device-proc.cuf
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 0ca636bc091ec..702a55a49c953 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6616,7 +6616,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
   mlir::Value arg1 =
       builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
   return builder
-      .create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
+      .create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
+                                      mlir::NVVM::VoteSyncKind::ballot)
       .getResult();
 }
 

diff  --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index a7f9038761b51..7d6d920dfb2e8 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -303,7 +303,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtestvote()
 ! CHECK: fir.call @llvm.nvvm.vote.all.sync
 ! CHECK: fir.call @llvm.nvvm.vote.any.sync
-! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
+! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
 
 ! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
 ! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..0a6e66919f021 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -808,15 +808,49 @@ def NVVM_ShflOp :
    let hasVerifier = 1;
 }
 
-def NVVM_VoteBallotOp :
-  NVVM_Op<"vote.ballot.sync">,
-  Results<(outs LLVM_Type:$res)>,
-  Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
+def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
+def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
+def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;
+
+def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
+                               [VoteSyncKindAny, VoteSyncKindAll,
+                                VoteSyncKindBallot, VoteSyncKindUni]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
+
+def NVVM_VoteSyncOp
+    : NVVM_Op<"vote.sync">,
+      Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
+      Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
+  let summary = "Vote across thread group";
+  let description = [{
+    The `vote.sync` op causes the executing thread to wait until all non-exited
+    threads corresponding to membermask have executed `vote.sync` with the same
+    qualifiers and same membermask value before resuming execution.
+
+    The vote operation kinds are:
+    - `any`: True if source predicate is True for some thread in membermask.
+    - `all`: True if source predicate is True for all non-exited threads in
+      membermask. 
+    - `uni`: True if source predicate has the same value in all non-exited
+      threads in membermask.
+    - `ballot`: In the ballot form, the destination result is a 32 bit integer.
+      In this form, the predicate from each thread in membermask is copied into
+      the corresponding bit position of the result, where the bit position
+      corresponds to the thread’s lane id.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync)
+  }];
   string llvmBuilder = [{
-      $res = createIntrinsicCall(builder,
-            llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
+    auto intId = getVoteSyncIntrinsicId($kind);
+    $res = createIntrinsicCall(builder, intId, {$mask, $pred});
   }];
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
+  let hasVerifier = 1;
 }
 
 def NVVM_SyncWarpOp :

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..09bff6101edd3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -48,34 +48,6 @@ using namespace NVVM;
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
 
-//===----------------------------------------------------------------------===//
-// Printing/parsing for NVVM ops
-//===----------------------------------------------------------------------===//
-
-static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
-  p << " " << op->getOperands();
-  if (op->getNumResults() > 0)
-    p << " : " << op->getResultTypes();
-}
-
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
-  MLIRContext *context = parser.getContext();
-  auto int32Ty = IntegerType::get(context, 32);
-  auto int1Ty = IntegerType::get(context, 1);
-
-  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
-  Type type;
-  return failure(parser.parseOperandList(ops) ||
-                 parser.parseOptionalAttrDict(result.attributes) ||
-                 parser.parseColonType(type) ||
-                 parser.addTypeToList(type, result.types) ||
-                 parser.resolveOperands(ops, {int32Ty, int1Ty},
-                                        parser.getNameLoc(), result.operands));
-}
-
-void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
 //===----------------------------------------------------------------------===//
 // Verifier methods
 //===----------------------------------------------------------------------===//
@@ -1160,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::VoteSyncOp::verify() {
+  if (getKind() == NVVM::VoteSyncKind::ballot) {
+    if (!getType().isInteger(32)) {
+      return emitOpError("vote.sync 'ballot' returns an i32");
+    }
+  } else {
+    if (!getType().isInteger(1)) {
+      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9d14ff09ab434..beff90237562d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
   }
 }
 
+static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
+  switch (kind) {
+  case NVVM::VoteSyncKind::any:
+    return llvm::Intrinsic::nvvm_vote_any_sync;
+  case NVVM::VoteSyncKind::all:
+    return llvm::Intrinsic::nvvm_vote_all_sync;
+  case NVVM::VoteSyncKind::ballot:
+    return llvm::Intrinsic::nvvm_vote_ballot_sync;
+  case NVVM::VoteSyncKind::uni:
+    return llvm::Intrinsic::nvvm_vote_uni_sync;
+  default:
+    llvm_unreachable("unsupported vote kind");
+  }
+}
+
 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
                                                   int32_t num) {

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 18bf39424f0bf..d3915492c38a0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -129,8 +129,14 @@ func.func @nvvm_shfl_pred(
 
 // CHECK-LABEL: @nvvm_vote(
 func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
-  // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
-  %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
+  // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+  %0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
+  // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+  %1 = nvvm.vote.sync all %arg0, %arg1 -> i1
+  // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+  %2 = nvvm.vote.sync any %arg0, %arg1 -> i1
+  // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+  %3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
   llvm.return %0 : i32
 }
 

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c3ec88db1d694..3a0713f2feee8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred(
 // CHECK-LABEL: @nvvm_vote
 llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
   // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
-  %3 = nvvm.vote.ballot.sync %0, %1 : i32
+  %3 = nvvm.vote.sync ballot %0, %1 -> i32
+  // CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
+  %4 = nvvm.vote.sync all %0, %1 -> i1
+  // CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
+  %5 = nvvm.vote.sync any %0, %1 -> i1
+  // CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
+  %6 = nvvm.vote.sync uni %0, %1 -> i1
   llvm.return %3 : i32
 }
 


        


More information about the flang-commits mailing list