[flang-commits] [flang] [mlir] [mlir][NVVM] Add ops for vote all and any sync (PR #134309)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Apr 3 16:01:12 PDT 2025


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/134309

>From b8caf553380327c273b8b9b5d17c40f7256cd3c5 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 3 Apr 2025 14:26:11 -0700
Subject: [PATCH 1/2] [mlir][NVVM] Add ops for vote all and any sync

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 20 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 22 +++++++++++++++++++--
 mlir/test/Dialect/LLVMIR/nvvm.mlir          |  4 ++++
 3 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..4a549d02dc281 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -819,6 +819,26 @@ def NVVM_VoteBallotOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">,
+                         Results<(outs LLVM_Type:$res)>,
+                         Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+  string llvmBuilder = [{
+      $res = createIntrinsicCall(builder,
+            llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
+def NVVM_VoteAnySyncOp : NVVM_Op<"vote.any.sync">,
+                         Results<(outs LLVM_Type:$res)>,
+                         Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+  string llvmBuilder = [{
+      $res = createIntrinsicCall(builder,
+            llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
 def NVVM_SyncWarpOp :
   NVVM_Op<"bar.warp.sync">,
   Arguments<(ins LLVM_Type:$mask)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..8ef74fcef90e8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -58,8 +58,7 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
     p << " : " << op->getResultTypes();
 }
 
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseVoteOps(OpAsmParser &parser, OperationState &result) {
   MLIRContext *context = parser.getContext();
   auto int32Ty = IntegerType::get(context, 32);
   auto int1Ty = IntegerType::get(context, 1);
@@ -74,8 +73,27 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
                                         parser.getNameLoc(), result.operands));
 }
 
+// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
+ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseVoteOps(parser, result);
+}
+
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
+// <operation> ::= `llvm.nvvm.vote.all.sync %mask, %pred` : result_type
+ParseResult VoteAllSyncOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseVoteOps(parser, result);
+}
+
+void VoteAllSyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+
+// <operation> ::= `llvm.nvvm.vote.any.sync %mask, %pred` : result_type
+ParseResult VoteAnySyncOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseVoteOps(parser, result);
+}
+
+void VoteAnySyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+
 //===----------------------------------------------------------------------===//
 // Verifier methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 18bf39424f0bf..9eec62ff67561 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -131,6 +131,10 @@ func.func @nvvm_shfl_pred(
 func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
   // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
   %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
+  // CHECK: nvvm.vote.all.sync %{{.*}}, %{{.*}} : i32
+  %1 = nvvm.vote.all.sync %arg0, %arg1 : i32
+  // CHECK: nvvm.vote.any.sync %{{.*}}, %{{.*}} : i32
+  %2 = nvvm.vote.any.sync %arg0, %arg1 : i32
   llvm.return %0 : i32
 }
 

>From 1ade5b5f532a9b0e80c84ef2ec4a0ff6dcf7495a Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 3 Apr 2025 15:43:49 -0700
Subject: [PATCH 2/2] Merge ops

---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp |  3 +-
 flang/test/Lower/CUDA/cuda-device-proc.cuf    |  2 +-
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 43 ++++++--------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 59 ++++---------------
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 15 +++++
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 14 +++--
 mlir/test/Target/LLVMIR/nvvmir.mlir           |  8 ++-
 7 files changed, 65 insertions(+), 79 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 2df9349269a69..fd4a6c5897364 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6542,7 +6542,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
   mlir::Value arg1 =
       builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
   return builder
-      .create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
+      .create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
+                                      mlir::NVVM::VoteSyncKind::ballot)
       .getResult();
 }
 
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index a7f9038761b51..7d6d920dfb2e8 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -303,7 +303,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtestvote()
 ! CHECK: fir.call @llvm.nvvm.vote.all.sync
 ! CHECK: fir.call @llvm.nvvm.vote.any.sync
-! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
+! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
 
 ! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
 ! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4a549d02dc281..f289eb9c1df0f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -808,35 +808,30 @@ def NVVM_ShflOp :
    let hasVerifier = 1;
 }
 
-def NVVM_VoteBallotOp :
-  NVVM_Op<"vote.ballot.sync">,
-  Results<(outs LLVM_Type:$res)>,
-  Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
-  string llvmBuilder = [{
-      $res = createIntrinsicCall(builder,
-            llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
-  }];
-  let hasCustomAssemblyFormat = 1;
+def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
+def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
+def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
+def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;
+
+def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
+                               [VoteSyncKindAny, VoteSyncKindAll,
+                                VoteSyncKindBallot, VoteSyncKindUni]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
 }
 
-def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">,
-                         Results<(outs LLVM_Type:$res)>,
-                         Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
-  string llvmBuilder = [{
-      $res = createIntrinsicCall(builder,
-            llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
-  }];
-  let hasCustomAssemblyFormat = 1;
-}
+def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
 
-def NVVM_VoteAnySyncOp : NVVM_Op<"vote.any.sync">,
-                         Results<(outs LLVM_Type:$res)>,
-                         Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+def NVVM_VoteSyncOp
+    : NVVM_Op<"vote.sync">,
+      Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
+      Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
   string llvmBuilder = [{
-      $res = createIntrinsicCall(builder,
-            llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+    auto intId = getVoteSyncIntrinsicId($kind);
+    $res = createIntrinsicCall(builder, intId, {$mask, $pred});
   }];
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
+  let hasVerifier = 1;
 }
 
 def NVVM_SyncWarpOp :
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8ef74fcef90e8..6ba4c372365db 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -48,52 +48,6 @@ using namespace NVVM;
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
 
-//===----------------------------------------------------------------------===//
-// Printing/parsing for NVVM ops
-//===----------------------------------------------------------------------===//
-
-static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
-  p << " " << op->getOperands();
-  if (op->getNumResults() > 0)
-    p << " : " << op->getResultTypes();
-}
-
-static ParseResult parseVoteOps(OpAsmParser &parser, OperationState &result) {
-  MLIRContext *context = parser.getContext();
-  auto int32Ty = IntegerType::get(context, 32);
-  auto int1Ty = IntegerType::get(context, 1);
-
-  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
-  Type type;
-  return failure(parser.parseOperandList(ops) ||
-                 parser.parseOptionalAttrDict(result.attributes) ||
-                 parser.parseColonType(type) ||
-                 parser.addTypeToList(type, result.types) ||
-                 parser.resolveOperands(ops, {int32Ty, int1Ty},
-                                        parser.getNameLoc(), result.operands));
-}
-
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
-  return parseVoteOps(parser, result);
-}
-
-void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
-// <operation> ::= `llvm.nvvm.vote.all.sync %mask, %pred` : result_type
-ParseResult VoteAllSyncOp::parse(OpAsmParser &parser, OperationState &result) {
-  return parseVoteOps(parser, result);
-}
-
-void VoteAllSyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
-// <operation> ::= `llvm.nvvm.vote.any.sync %mask, %pred` : result_type
-ParseResult VoteAnySyncOp::parse(OpAsmParser &parser, OperationState &result) {
-  return parseVoteOps(parser, result);
-}
-
-void VoteAnySyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
 //===----------------------------------------------------------------------===//
 // Verifier methods
 //===----------------------------------------------------------------------===//
@@ -1178,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
   return success();
 }
 
+// Verify that the result type matches the vote kind: 'ballot' must produce an
+// i32, while 'any', 'all' and 'uni' must produce an i1.
+LogicalResult NVVM::VoteSyncOp::verify() {
+  if (getKind() == NVVM::VoteSyncKind::ballot) {
+    if (!getType().isInteger(32))
+      return emitOpError("vote.sync 'ballot' returns an i32");
+  } else {
+    if (!getType().isInteger(1))
+      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9d14ff09ab434..beff90237562d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
   }
 }
 
+/// Return the LLVM intrinsic ID associated with the given vote.sync kind.
+static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
+  switch (kind) {
+  case NVVM::VoteSyncKind::any:
+    return llvm::Intrinsic::nvvm_vote_any_sync;
+  case NVVM::VoteSyncKind::all:
+    return llvm::Intrinsic::nvvm_vote_all_sync;
+  case NVVM::VoteSyncKind::ballot:
+    return llvm::Intrinsic::nvvm_vote_ballot_sync;
+  case NVVM::VoteSyncKind::uni:
+    return llvm::Intrinsic::nvvm_vote_uni_sync;
+  }
+  llvm_unreachable("unsupported vote kind");
+}
+
 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
                                                   int32_t num) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 9eec62ff67561..d3915492c38a0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -129,12 +129,14 @@ func.func @nvvm_shfl_pred(
 
 // CHECK-LABEL: @nvvm_vote(
 func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
-  // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
-  %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
-  // CHECK: nvvm.vote.all.sync %{{.*}}, %{{.*}} : i32
-  %1 = nvvm.vote.all.sync %arg0, %arg1 : i32
-  // CHECK: nvvm.vote.any.sync %{{.*}}, %{{.*}} : i32
-  %2 = nvvm.vote.any.sync %arg0, %arg1 : i32
+  // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+  %0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
+  // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+  %1 = nvvm.vote.sync all %arg0, %arg1 -> i1
+  // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+  %2 = nvvm.vote.sync any %arg0, %arg1 -> i1
+  // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+  %3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
   llvm.return %0 : i32
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c3ec88db1d694..3a0713f2feee8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred(
 // CHECK-LABEL: @nvvm_vote
 llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
   // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
-  %3 = nvvm.vote.ballot.sync %0, %1 : i32
+  %3 = nvvm.vote.sync ballot %0, %1 -> i32
+  // CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
+  %4 = nvvm.vote.sync all %0, %1 -> i1
+  // CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
+  %5 = nvvm.vote.sync any %0, %1 -> i1
+  // CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
+  %6 = nvvm.vote.sync uni %0, %1 -> i1
   llvm.return %3 : i32
 }
 



More information about the flang-commits mailing list