[Mlir-commits] [mlir] [MLIR][NVVM] Add support for match.sync Op (PR #130718)

Srinivasa Ravi llvmlistbot at llvm.org
Mon Mar 10 22:18:24 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/130718

>From d3d080f3213e84b5c62b2e0af9cd483c1f3d665a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 10 Mar 2025 18:13:51 +0530
Subject: [PATCH] [MLIR][NVVM] Add support for match.sync Op

This change adds the `match.sync` Op to the MLIR NVVM dialect to generate
the `match.sync` PTX instruction.

PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-match-sync
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 46 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 16 +++++++
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 17 +++++++
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 29 ++++++++++++
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 13 ++++++
 5 files changed, 121 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 944cb481b025b..161c820ac4b03 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2583,6 +2583,52 @@ def NVVM_MapaOp: NVVM_Op<"mapa",
   let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM match.sync Op
+//===----------------------------------------------------------------------===//
+
+def MatchSyncKindAny : I32EnumAttrCase<"any", 0>;
+def MatchSyncKindAll : I32EnumAttrCase<"all", 1>;
+
+def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
+  [MatchSyncKindAny, MatchSyncKindAll]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
+
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins I32:$thread_mask,
+                 AnyTypeOf<[I32, I64]>:$val,
+                 MatchSyncKindAttr:$kind)> {
+  let summary = "NVVM Dialect Op for match.sync";
+  let description = [{
+    The `match.sync` op performs a broadcast and compare of operand `val`
+    across all non-exited threads in `thread_mask` and returns a mask
+    depending on the kind, plus a predicate for the `all` kind.
+
+    The matching operation kinds are:
+    - `any`: Returns the mask of non-exited threads in `thread_mask` that
+    have the same value of operand `val`. The result type is `i32`.
+    - `all`: Returns the mask of non-exited threads in `thread_mask` and a
+    predicate set to true if all non-exited threads in `thread_mask` have
+    the same value of operand `val`; otherwise returns 0 and a predicate
+    set to false. The result type is `!llvm.struct<(i32, i1)>`.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-match-sync)
+  }];
+  string llvmBuilder = [{
+    auto intId = getMatchSyncIntrinsicId(
+        op.getVal().getType(), $kind);
+    $res = createIntrinsicCall(builder,
+        intId, {$thread_mask, $val});
+  }];
+  let assemblyFormat = "$kind $thread_mask `,` $val attr-dict `:` type($val) `->` type($res)";
+  let hasVerifier = 1;
+}
+
 def NVVM_Exit : NVVM_Op<"exit"> {
   let summary = "Exit Op";
   let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 45a0f9dbd4a7c..737520c8f27f2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1137,6 +1137,22 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
   return success();
 }
 
+/// Verify the result type: i32 for `any`, !llvm.struct<(i32, i1)> for `all`.
+LogicalResult NVVM::MatchSyncOp::verify() {
+  if (getKind() != NVVM::MatchSyncKind::all) {
+    if (!getType().isInteger(32))
+      return emitOpError("match.sync 'any' returns an i32");
+    return success();
+  }
+  auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+  if (!structType || structType.getBody().size() != 2 ||
+      !structType.getBody()[0].isInteger(32) ||
+      !structType.getBody()[1].isInteger(1))
+    return emitOpError("match.sync 'all' returns a two element struct with "
+                       "first element as i32 and second element as i1");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9540762de2777..e93e2f7adda6b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -106,6 +106,23 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
   llvm_unreachable("unknown shuffle kind");
 }
 
+/// Return the match.sync intrinsic ID for `kind`; `valType` is i32 or i64.
+static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
+                                                   NVVM::MatchSyncKind kind) {
+  switch (kind) {
+  case NVVM::MatchSyncKind::any:
+    return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
+                                 : llvm::Intrinsic::nvvm_match_any_sync_i64;
+  case NVVM::MatchSyncKind::all:
+    // match.all instruction has two variants -- one returns a single value,
+    // another returns a pair {value, predicate}. We currently only implement
+    // the latter as that's the variant exposed by CUDA API.
+    return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
+                                 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
+  }
+  llvm_unreachable("unknown match sync kind");
+}
+
 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
                                                   int32_t num) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 85998d4e66254..cab0311cc4555 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -550,6 +550,35 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   return
 }
 
+// CHECK-LABEL: @match_sync
+func.func @match_sync(%val32: i32, %val64: i64, %thread_mask: i32) {
+  // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
+  %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> i32
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+  %1 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
+  %2 = nvvm.match.sync any %thread_mask, %val64 : i64 -> i32
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+  %3 = nvvm.match.sync all %thread_mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
+  return
+}
+
+// -----
+
+func.func @match_sync_error1(%value: i32, %mask: i32) {
+  // expected-error @below {{match.sync 'all' returns a two element struct with first element as i32 and second element as i1}}
+  %result = nvvm.match.sync all %mask, %value : i32 -> !llvm.struct<(i32, i8)>
+  return
+}
+
+// -----
+
+func.func @match_sync_error2(%value: i32, %mask: i32) {
+  // expected-error @below {{match.sync 'any' returns an i32}}
+  %result = nvvm.match.sync any %mask, %value : i32 -> !llvm.struct<(i32, i1)>
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a3a70fcebb7c7..c8008a89813e8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -810,3 +810,16 @@ llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
   %7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
   llvm.return
 }
+
+// CHECK-LABEL: @nvvm_match_sync
+llvm.func @nvvm_match_sync(%thread_mask: i32, %v32: i32, %v64: i64) {
+  // CHECK: call i32 @llvm.nvvm.match.any.sync.i32(i32 %{{.*}}, i32 %{{.*}})
+  %any32 = nvvm.match.sync any %thread_mask, %v32 : i32 -> i32
+  // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i32p(i32 %{{.*}}, i32 %{{.*}})
+  %all32 = nvvm.match.sync all %thread_mask, %v32 : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: call i32 @llvm.nvvm.match.any.sync.i64(i32 %{{.*}}, i64 %{{.*}})
+  %any64 = nvvm.match.sync any %thread_mask, %v64 : i64 -> i32
+  // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i64p(i32 %{{.*}}, i64 %{{.*}})
+  %all64 = nvvm.match.sync all %thread_mask, %v64 : i64 -> !llvm.struct<(i32, i1)>
+  llvm.return
+}



More information about the Mlir-commits mailing list