[Mlir-commits] [mlir] c42952a - [MLIR][NVVM] Add support for match.sync Op (#130718)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 18 02:24:27 PDT 2025
Author: Srinivasa Ravi
Date: 2025-03-18T14:54:24+05:30
New Revision: c42952a782a65d7988e3cb81e920662cc97c1b1e
URL: https://github.com/llvm/llvm-project/commit/c42952a782a65d7988e3cb81e920662cc97c1b1e
DIFF: https://github.com/llvm/llvm-project/commit/c42952a782a65d7988e3cb81e920662cc97c1b1e.diff
LOG: [MLIR][NVVM] Add support for match.sync Op (#130718)
This change adds the `match.sync` Op to the MLIR NVVM dialect to
generate the `match.sync` PTX instruction.
PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-match-sync
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ff6696f6bec40..7c2e042b55248 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -19,6 +19,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Dialect/LLVMIR/LLVMTypes.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -2583,6 +2584,52 @@ def NVVM_MapaOp: NVVM_Op<"mapa",
let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)";
}
+//===----------------------------------------------------------------------===//
+// NVVM match.sync Op
+//===----------------------------------------------------------------------===//
+
+// Match kinds accepted by nvvm.match.sync, printed/parsed as `any` / `all`.
+def MatchSyncKindAny : I32EnumAttrCase<"any", 0>;
+def MatchSyncKindAll : I32EnumAttrCase<"all", 1>;
+
+// Plain C++ enum ::mlir::NVVM::MatchSyncKind; no specialized attribute class is
+// generated here (genSpecializedAttr = 0) because MatchSyncKindAttr wraps it.
+def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
+ [MatchSyncKindAny, MatchSyncKindAll]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+// Dialect attribute carrying a MatchSyncKind, with mnemonic `match_sync_kind`.
+def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
+
+// nvvm.match.sync: `val` may be i32 or i64; the mask operand is i32.
+// The ODS result constraint only admits i32 or an LLVM struct; the exact
+// per-kind result shape (i32 for `any`, !llvm.struct<(i32, i1)> for `all`)
+// is enforced by the C++ verifier requested via `hasVerifier` below.
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+ Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
+ Arguments<(ins I32:$thread_mask,
+ AnyTypeOf<[I32, I64]>:$val,
+ MatchSyncKindAttr:$kind)> {
+ let summary = "Broadcast and compare a value across threads in warp";
+ let description = [{
+ The `match.sync` op performs broadcast and compare of operand `val` across
+ all non-exited threads in `thread_mask` and returns a mask depending on the
+ kind and an optional predicate.
+
+ The matching operation kinds are:
+ - `any`: Returns a mask corresponding to the non-exited threads in the
+ `thread_mask` that have the same value of operand `val`.
+ - `all`: Returns a mask and a predicate. If all non-exited threads in the
+ `thread_mask` have the same value of operand `val`, the predicate is set to
+ true and the mask corresponds to the non-exited threads in the
+ `thread_mask`. Otherwise, the predicate is set to false and the mask is 0.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-match-sync)
+ }];
+ // LLVM IR translation: pick the intrinsic from the type of `val` and the
+ // kind, then call it with (thread_mask, val).
+ string llvmBuilder = [{
+ auto intId = getMatchSyncIntrinsicId(
+ op.getVal().getType(), $kind);
+ $res = createIntrinsicCall(builder,
+ intId, {$thread_mask, $val});
+ }];
+ let assemblyFormat = "$kind $thread_mask `,` $val attr-dict `:` type($val) `->` type($res)";
+ // Result type vs. kind is checked in NVVM::MatchSyncOp::verify().
+ let hasVerifier = 1;
+}
+
def NVVM_Exit : NVVM_Op<"exit"> {
let summary = "Exit Op";
let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8f080a2d597a5..ce93cb1ca4297 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1138,6 +1138,22 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
return success();
}
+/// Verify that the result type matches the match kind: `any` yields the i32
+/// mask alone, while `all` yields a two-element struct {i32 mask, i1 pred}.
+LogicalResult NVVM::MatchSyncOp::verify() {
+  Type resultType = getType();
+
+  // Every kind other than `all` (i.e. `any`) must produce a bare i32 mask.
+  if (getKind() != NVVM::MatchSyncKind::all) {
+    if (!resultType.isInteger(32))
+      return emitOpError("match.sync 'any' returns an i32");
+    return success();
+  }
+
+  // `all` must produce exactly {i32, i1}.
+  auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(resultType);
+  bool isMaskAndPredicate = structType && structType.getBody().size() == 2 &&
+                            structType.getBody()[0].isInteger(32) &&
+                            structType.getBody()[1].isInteger(1);
+  if (!isMaskAndPredicate)
+    return emitOpError("match.sync 'all' returns a two element struct with "
+                       "first element as i32 and second element as i1");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index c3a129a82688f..74e84a8e8e9a6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -106,6 +106,23 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
llvm_unreachable("unknown shuffle kind");
}
+/// Return the NVVM intrinsic ID implementing `match.sync` for the given match
+/// kind and type of the matched value. An i32 value selects the `_i32`
+/// variant; any other type is treated as i64.
+static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
+                                                   NVVM::MatchSyncKind kind) {
+  bool is32Bit = valType.isInteger(32);
+  switch (kind) {
+  case NVVM::MatchSyncKind::any:
+    return is32Bit ? llvm::Intrinsic::nvvm_match_any_sync_i32
+                   : llvm::Intrinsic::nvvm_match_any_sync_i64;
+  case NVVM::MatchSyncKind::all:
+    // match.all instruction has two variants -- one returns a single value,
+    // another returns a pair {value, predicate}. We currently only implement
+    // the latter as that's the variant exposed by CUDA API.
+    return is32Bit ? llvm::Intrinsic::nvvm_match_all_sync_i32p
+                   : llvm::Intrinsic::nvvm_match_all_sync_i64p;
+  }
+  // No `default:` label above: the switch covers every MatchSyncKind
+  // enumerator, so -Wswitch will flag this function if a new kind is added.
+  llvm_unreachable("unknown match sync kind");
+}
+
/// Return the intrinsic ID associated with ldmatrix for the given parameters.
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
int32_t num) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 85998d4e66254..903dbb1bb5c8e 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -550,6 +550,19 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
return
}
+// Round-trip nvvm.match.sync for both kinds with i32 and i64 values.
+// CHECK-LABEL: @match_sync
+func.func @match_sync(%val32: i32, %val64: i64, %thread_mask: i32) {
+ // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
+ %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> i32
+ // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+ %1 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+ // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
+ %2 = nvvm.match.sync any %thread_mask, %val64 : i64 -> i32
+ // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+ %3 = nvvm.match.sync all %thread_mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
+ return
+}
+
// -----
// Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 4fca7fd801dbe..97fbbe2fe5fa3 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -152,3 +152,19 @@ llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
}
llvm.return
}
+
+// -----
+
+// `all` with an i8 (instead of i1) predicate field must be rejected.
+llvm.func @nvvm_match_sync_all(%val32: i32, %thread_mask: i32) {
+ // expected-error @below {{match.sync 'all' returns a two element struct with first element as i32 and second element as i1}}
+ %0 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i8)>
+ llvm.return
+}
+
+// -----
+
+// `any` returning a struct instead of a bare i32 must be rejected.
+llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
+ // expected-error @below {{match.sync 'any' returns an i32}}
+ %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f39aca95b918f..c113cd2fcf5f7 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -810,3 +810,16 @@ llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
%7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
llvm.return
}
+
+// Each kind/value-width combination lowers to its own NVVM intrinsic.
+// CHECK-LABEL: @nvvm_match_sync
+llvm.func @nvvm_match_sync(%mask: i32, %val32: i32, %val64: i64) {
+ // CHECK: call i32 @llvm.nvvm.match.any.sync.i32(i32 %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.match.sync any %mask, %val32 : i32 -> i32
+ // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i32p(i32 %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.match.sync all %mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
+ // CHECK: call i32 @llvm.nvvm.match.any.sync.i64(i32 %{{.*}}, i64 %{{.*}})
+ %2 = nvvm.match.sync any %mask, %val64 : i64 -> i32
+ // CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i64p(i32 %{{.*}}, i64 %{{.*}})
+ %3 = nvvm.match.sync all %mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
+ llvm.return
+}
More information about the Mlir-commits
mailing list