[Mlir-commits] [mlir] c42952a - [MLIR][NVVM] Add support for match.sync Op (#130718)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 18 02:24:27 PDT 2025


Author: Srinivasa Ravi
Date: 2025-03-18T14:54:24+05:30
New Revision: c42952a782a65d7988e3cb81e920662cc97c1b1e

URL: https://github.com/llvm/llvm-project/commit/c42952a782a65d7988e3cb81e920662cc97c1b1e
DIFF: https://github.com/llvm/llvm-project/commit/c42952a782a65d7988e3cb81e920662cc97c1b1e.diff

LOG: [MLIR][NVVM] Add support for match.sync Op (#130718)

This change adds the `match.sync` Op to the MLIR NVVM dialect to
generate the `match.sync` PTX instruction.

PTX Spec Reference:

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-match-sync

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ff6696f6bec40..7c2e042b55248 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -19,6 +19,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Dialect/LLVMIR/LLVMTypes.td"
 
 def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -2583,6 +2584,52 @@ def NVVM_MapaOp: NVVM_Op<"mapa",
   let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM match.sync Op: kind enum (any/all), its attribute, and the op itself
+//===----------------------------------------------------------------------===//
+
+def MatchSyncKindAny : I32EnumAttrCase<"any", 0>; // op result: i32 mask
+def MatchSyncKindAll : I32EnumAttrCase<"all", 1>; // op result: struct<(i32, i1)>
+
+def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
+  [MatchSyncKindAny, MatchSyncKindAll]> {
+  let genSpecializedAttr = 0; // attribute comes from MatchSyncKindAttr (EnumAttr) below
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
+
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+  Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>, // exact type per kind is checked by the verifier
+  Arguments<(ins I32:$thread_mask,
+                 AnyTypeOf<[I32, I64]>:$val,
+                 MatchSyncKindAttr:$kind)> {
+  let summary = "Broadcast and compare a value across threads in warp";
+  let description = [{
+    The `match.sync` op performs broadcast and compare of operand `val` across 
+    all non-exited threads in `thread_mask` and returns a mask depending on the 
+    kind and an optional predicate.
+
+    The matching operation kinds are:
+    - `any`: Returns a mask corresponding to the non-exited threads in the 
+    `thread_mask` that have the same value of operand `val`.
+    - `all`: Returns a mask and a predicate. If all non-exited threads in the 
+    `thread_mask` have the same value of operand `val`, the predicate is set to 
+    true and the mask corresponds to the non-exited threads in the 
+    `thread_mask`. Otherwise, the predicate is set to false and the mask is 0.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-match-sync)
+  }];
+  string llvmBuilder = [{
+    auto intId = getMatchSyncIntrinsicId(
+        op.getVal().getType(), $kind);
+    $res = createIntrinsicCall(builder,
+        intId, {$thread_mask, $val});
+  }];
+  let assemblyFormat = "$kind $thread_mask `,` $val attr-dict `:` type($val) `->` type($res)";
+  let hasVerifier = 1; // implemented as NVVM::MatchSyncOp::verify() in NVVMDialect.cpp
+}
+
 def NVVM_Exit : NVVM_Op<"exit"> {
   let summary = "Exit Op";
   let description = [{

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8f080a2d597a5..ce93cb1ca4297 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1138,6 +1138,22 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::MatchSyncOp::verify() {
+  // 'all' must yield {i32 mask, i1 predicate}; any other kind yields an i32.
+  if (getKind() == NVVM::MatchSyncKind::all) {
+    auto structTy = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+    if (!structTy || structTy.getBody().size() != 2 ||
+        !structTy.getBody()[0].isInteger(32) ||
+        !structTy.getBody()[1].isInteger(1)) {
+      return emitOpError("match.sync 'all' returns a two element struct with "
+                         "first element as i32 and second element as i1");
+    }
+  } else if (!getType().isInteger(32)) {
+    return emitOpError("match.sync 'any' returns an i32");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index c3a129a82688f..74e84a8e8e9a6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -106,6 +106,23 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
   llvm_unreachable("unknown shuffle kind");
 }
 
+/// Return the match.sync intrinsic ID for `kind`; non-i32 values map to i64.
+static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
+                                                   NVVM::MatchSyncKind kind) {
+  switch (kind) {
+  case NVVM::MatchSyncKind::any:
+    return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
+                                 : llvm::Intrinsic::nvvm_match_any_sync_i64;
+  case NVVM::MatchSyncKind::all:
+    // match.all instruction has two variants -- one returns a single value,
+    // another returns a pair {value, predicate}. We currently only implement
+    // the latter as that's the variant exposed by CUDA API.
+    return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
+                                 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
+  }
+  llvm_unreachable("unknown match sync kind");
+}
+
 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
                                                   int32_t num) {

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 85998d4e66254..903dbb1bb5c8e 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -550,6 +550,19 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   return
 }
 
+// CHECK-LABEL: @match_sync
+func.func @match_sync(%val32: i32, %val64: i64, %thread_mask: i32) { // any/all kinds with i32 and i64 values
+  // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
+  %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> i32
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+  %1 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
+  %2 = nvvm.match.sync any %thread_mask, %val64 : i64 -> i32
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+  %3 = nvvm.match.sync all %thread_mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
+  return 
+}
+
 // -----
 
 // Just check these don't emit errors.

diff  --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 4fca7fd801dbe..97fbbe2fe5fa3 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -152,3 +152,19 @@ llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
   }
   llvm.return
 }
+
+// -----
+
+llvm.func @nvvm_match_sync_all(%val32: i32, %thread_mask: i32) { // 'all' with an i8 second field instead of i1
+  // expected-error @below {{match.sync 'all' returns a two element struct with first element as i32 and second element as i1}}
+  %0 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i8)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) { // 'any' with a struct result instead of i32
+  // expected-error @below {{match.sync 'any' returns an i32}}
+  %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f39aca95b918f..c113cd2fcf5f7 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -810,3 +810,16 @@ llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
   %7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
   llvm.return
 }
+
+// CHECK-LABEL: @nvvm_match_sync
+llvm.func @nvvm_match_sync(%mask: i32, %val32: i32, %val64: i64) { // one intrinsic call expected per kind/value-width pair
+  // CHECK: call i32 @llvm.nvvm.match.any.sync.i32(i32 %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.match.sync any %mask, %val32 : i32 -> i32
+  // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i32p(i32 %{{.*}}, i32 %{{.*}})
+  %1 = nvvm.match.sync all %mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: call i32 @llvm.nvvm.match.any.sync.i64(i32 %{{.*}}, i64 %{{.*}})
+  %2 = nvvm.match.sync any %mask, %val64 : i64 -> i32
+  // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i64p(i32 %{{.*}}, i64 %{{.*}})
+  %3 = nvvm.match.sync all %mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
+  llvm.return
+}


        


More information about the Mlir-commits mailing list