[Mlir-commits] [mlir] [MLIR][NVVM] Add support for match.sync Op (PR #130718)
Srinivasa Ravi
llvmlistbot at llvm.org
Sun Mar 16 21:47:53 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/130718
>From f5de80b0df91cfcacce27efa2500ca6ac8f5e386 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 10 Mar 2025 18:13:51 +0530
Subject: [PATCH] [MLIR][NVVM] Add support for match.sync Op
This change adds the `match.sync` Op to the MLIR NVVM dialect to generate
the `match.sync` PTX instruction.
PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-match-sync
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 47 +++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 16 +++++++
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 17 +++++++
mlir/test/Dialect/LLVMIR/nvvm.mlir | 13 +++++
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 16 +++++++
mlir/test/Target/LLVMIR/nvvmir.mlir | 13 +++++
6 files changed, 122 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 944cb481b025b..6de1bebe261ac 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -19,6 +19,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Dialect/LLVMIR/LLVMTypes.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -2583,6 +2584,52 @@ def NVVM_MapaOp: NVVM_Op<"mapa",
let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)";
}
+//===----------------------------------------------------------------------===//
+// NVVM match.sync Op
+//===----------------------------------------------------------------------===//
+
+// Kinds accepted by `nvvm.match.sync`. The case strings ("any"/"all") are the
+// keywords used for `$kind` in the op's textual form.
+def MatchSyncKindAny : I32EnumAttrCase<"any", 0>;
+def MatchSyncKindAll : I32EnumAttrCase<"all", 1>;
+
+def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
+ [MatchSyncKindAny, MatchSyncKindAll]> {
+ // Do not generate the specialized integer-backed attribute; the dialect
+ // attribute `MatchSyncKindAttr` below wraps this enum instead.
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+// NVVM dialect attribute holding a MatchSyncKind value (mnemonic
+// `match_sync_kind`); used as the `kind` attribute of NVVM_MatchSyncOp.
+def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
+
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+  Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
+  Arguments<(ins I32:$thread_mask,
+                 AnyTypeOf<[I32, I64]>:$val,
+                 MatchSyncKindAttr:$kind)> {
+  let summary = "Broadcast and compare a value across threads in warp";
+  let description = [{
+    The `match.sync` op broadcasts operand `val` (an `i32` or `i64`) and
+    compares it across all non-exited threads in the `i32` `thread_mask`.
+    Depending on `kind`, it returns either a mask alone or a mask together
+    with a predicate.
+
+    The matching operation kinds are:
+    - `any`: Returns an `i32` mask corresponding to the non-exited threads in
+      the `thread_mask` that have the same value of operand `val`.
+    - `all`: Returns a `!llvm.struct<(i32, i1)>` holding a mask and a
+      predicate. If all non-exited threads in the `thread_mask` have the same
+      value of operand `val`, the predicate is set to true and the mask
+      corresponds to the non-exited threads in the `thread_mask`. Otherwise,
+      the predicate is set to false and the mask is 0.
+
+    Example:
+
+    ```mlir
+    %mask = nvvm.match.sync any %thread_mask, %val32 : i32 -> i32
+    %pair = nvvm.match.sync all %thread_mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
+    ```
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-match-sync)
+  }];
+  // Picks the i32/i64 flavor of the any/all intrinsic from the operand type
+  // and kind, then calls it with (thread_mask, val).
+  string llvmBuilder = [{
+    auto intId = getMatchSyncIntrinsicId(
+      op.getVal().getType(), $kind);
+    $res = createIntrinsicCall(builder,
+      intId, {$thread_mask, $val});
+  }];
+  // The result type is only loosely constrained in `Results`; the custom
+  // verifier ties it to `kind` (i32 for `any`, struct<(i32, i1)> for `all`).
+  let assemblyFormat = "$kind $thread_mask `,` $val attr-dict `:` type($val) `->` type($res)";
+  let hasVerifier = 1;
+}
+
def NVVM_Exit : NVVM_Op<"exit"> {
let summary = "Exit Op";
let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 45a0f9dbd4a7c..737520c8f27f2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1137,6 +1137,22 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
return success();
}
+/// Checks that the result type of `nvvm.match.sync` agrees with its kind:
+/// `all` must yield a two-element struct whose elements are the i32 mask and
+/// the i1 predicate, while `any` must yield a plain i32 mask.
+LogicalResult NVVM::MatchSyncOp::verify() {
+  if (getKind() == NVVM::MatchSyncKind::all) {
+    // Named `structType` (not `Type`) so the mlir::Type class name is not
+    // shadowed inside this scope.
+    auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+    bool isMaskPredicatePair = structType &&
+                               structType.getBody().size() == 2 &&
+                               structType.getBody()[0].isInteger(32) &&
+                               structType.getBody()[1].isInteger(1);
+    if (!isMaskPredicatePair)
+      return emitOpError("match.sync 'all' returns a two element struct with "
+                         "first element as i32 and second element as i1");
+    return success();
+  }
+
+  if (!getType().isInteger(32))
+    return emitOpError("match.sync 'any' returns an i32");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9540762de2777..e93e2f7adda6b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -106,6 +106,23 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
llvm_unreachable("unknown shuffle kind");
}
+/// Returns the NVVM intrinsic implementing `match.sync` for the given operand
+/// type and kind. `valType` is expected to be i32 or i64 (the op constrains
+/// `val` to those types); any non-i32 type selects the i64 intrinsic.
+static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
+                                                   NVVM::MatchSyncKind kind) {
+  bool is32Bit = valType.isInteger(32);
+  // No `default:` label: the switch covers every MatchSyncKind enumerator, so
+  // the compiler can flag any kind added later without a case here.
+  switch (kind) {
+  case NVVM::MatchSyncKind::any:
+    return is32Bit ? llvm::Intrinsic::nvvm_match_any_sync_i32
+                   : llvm::Intrinsic::nvvm_match_any_sync_i64;
+  case NVVM::MatchSyncKind::all:
+    // match.all instruction has two variants -- one returns a single value,
+    // another returns a pair {value, predicate}. We currently only implement
+    // the latter as that's the variant exposed by CUDA API.
+    return is32Bit ? llvm::Intrinsic::nvvm_match_all_sync_i32p
+                   : llvm::Intrinsic::nvvm_match_all_sync_i64p;
+  }
+  llvm_unreachable("unknown match sync kind");
+}
+
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
int32_t num) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 85998d4e66254..903dbb1bb5c8e 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -550,6 +550,19 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
return
}
+// CHECK-LABEL: @match_sync
+func.func @match_sync(%v32: i32, %v64: i64, %mask: i32) {
+  // Round-trip both kinds for each supported operand width.
+  // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
+  %any_i32 = nvvm.match.sync any %mask, %v32 : i32 -> i32
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+  %all_i32 = nvvm.match.sync all %mask, %v32 : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
+  %any_i64 = nvvm.match.sync any %mask, %v64 : i64 -> i32
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+  %all_i64 = nvvm.match.sync all %mask, %v64 : i64 -> !llvm.struct<(i32, i1)>
+  return
+}
+
// -----
// Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 4fca7fd801dbe..97fbbe2fe5fa3 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -152,3 +152,19 @@ llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
}
llvm.return
}
+
+// -----
+
+llvm.func @nvvm_match_sync_all(%val32: i32, %thread_mask: i32) {
+  // expected-error @below {{match.sync 'all' returns a two element struct with first element as i32 and second element as i1}}
+  %0 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i8)>
+  llvm.return
+}
+
+// -----
+
+// `all` with a non-struct result exercises the verifier's dyn_cast failure path.
+llvm.func @nvvm_match_sync_all_non_struct(%val32: i32, %thread_mask: i32) {
+  // expected-error @below {{match.sync 'all' returns a two element struct with first element as i32 and second element as i1}}
+  %0 = nvvm.match.sync all %thread_mask, %val32 : i32 -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
+  // expected-error @below {{match.sync 'any' returns an i32}}
+  %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a3a70fcebb7c7..c8008a89813e8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -810,3 +810,16 @@ llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
%7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
llvm.return
}
+
+// CHECK-LABEL: @nvvm_match_sync
+llvm.func @nvvm_match_sync(%membermask: i32, %v32: i32, %v64: i64) {
+  // Each kind/width pair must lower to its dedicated NVVM intrinsic.
+  // CHECK: call i32 @llvm.nvvm.match.any.sync.i32(i32 %{{.*}}, i32 %{{.*}})
+  %any_i32 = nvvm.match.sync any %membermask, %v32 : i32 -> i32
+  // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i32p(i32 %{{.*}}, i32 %{{.*}})
+  %all_i32 = nvvm.match.sync all %membermask, %v32 : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: call i32 @llvm.nvvm.match.any.sync.i64(i32 %{{.*}}, i64 %{{.*}})
+  %any_i64 = nvvm.match.sync any %membermask, %v64 : i64 -> i32
+  // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i64p(i32 %{{.*}}, i64 %{{.*}})
+  %all_i64 = nvvm.match.sync all %membermask, %v64 : i64 -> !llvm.struct<(i32, i1)>
+  llvm.return
+}
More information about the Mlir-commits
mailing list