[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f32 in redux.sync Op (PR #128137)
Srinivasa Ravi
llvmlistbot at llvm.org
Thu Feb 20 21:51:54 PST 2025
https://github.com/Wolfram70 created https://github.com/llvm/llvm-project/pull/128137
This change adds support for the f32 variants of the `redux.sync` instruction to the NVVM dialect by lowering them to the newly added `llvm.nvvm.redux.sync.fmin`/`fmax` intrinsics, including their `abs` and `NaN` variants.
>From fdbf42b9df538698b62fb8f4420228e7519821d1 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 14 Feb 2025 23:24:16 +0530
Subject: [PATCH] [MLIR][NVVM] Add support for f32 in redux.sync Op
This change adds support for the f32 variants of the `redux.sync`
instruction in the NVVM Dialect through the newly added intrinsics for
the same.
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 10 +++--
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 17 +++++++-
mlir/test/Dialect/LLVMIR/nvvm.mlir | 19 ++++++++
mlir/test/Target/LLVMIR/nvvmir.mlir | 43 +++++++++++++++++++
4 files changed, 84 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..92bc73b0d03ff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -257,11 +257,13 @@ def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">;
def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">;
+def ReduxKindFmin : I32EnumAttrCase<"FMIN", 9, "fmin">;
+def ReduxKindFmax : I32EnumAttrCase<"FMAX", 10, "fmax">;
/// Enum attribute of the different kinds.
def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
[ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
- ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+ ReduxKindUmax, ReduxKindUmin, ReduxKindXor, ReduxKindFmin, ReduxKindFmax]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -273,9 +275,11 @@ def NVVM_ReduxOp :
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$val,
ReduxKindAttr:$kind,
- I32:$mask_and_clamp)> {
+ I32:$mask_and_clamp,
+ DefaultValuedAttr<BoolAttr, "false">:$abs,
+ DefaultValuedAttr<BoolAttr, "false">:$nan)> {
string llvmBuilder = [{
- auto intId = getReduxIntrinsicId($_resultType, $kind);
+ auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
$res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
}];
let assemblyFormat = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 8b13735774663..721778be6ba20 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,9 +25,18 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
+#define REDUX_F32_ID_IMPL(op, abs, nan) \
+ hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##nan \
+ : llvm::Intrinsic::nvvm_redux_sync_f##op##abs
+
+#define GET_REDUX_F32_ID(op, abs, nan) \
+ hasAbs ? REDUX_F32_ID_IMPL(op, abs, nan) : REDUX_F32_ID_IMPL(op, , nan)
+
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
- NVVM::ReduxKind kind) {
- if (!resultType->isIntegerTy(32))
+ NVVM::ReduxKind kind,
+ bool hasAbs,
+ bool hasNaN) {
+ if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
llvm_unreachable("unsupported data type for redux");
switch (kind) {
@@ -47,6 +56,10 @@ static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
return llvm::Intrinsic::nvvm_redux_sync_max;
case NVVM::ReduxKind::MIN:
return llvm::Intrinsic::nvvm_redux_sync_min;
+ case NVVM::ReduxKind::FMIN:
+ return GET_REDUX_F32_ID(min, _abs, _NaN);
+ case NVVM::ReduxKind::FMAX:
+ return GET_REDUX_F32_ID(max, _abs, _NaN);
}
llvm_unreachable("unknown redux kind");
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index dd54acd1e317e..85998d4e66254 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -411,6 +411,25 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
llvm.return %r1 : i32
}
+llvm.func @redux_sync_f32(%value: f32, %offset: i32) -> f32 {
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r1 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r2 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r3 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmin %{{.*}}
+ %r4 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r5 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r6 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r7 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
+ // CHECK: nvvm.redux.sync fmax %{{.*}}
+ %r8 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
+ llvm.return %r1 : f32
+}
// -----
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 5ab593452ab66..d11558698d860 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -780,3 +780,46 @@ llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
llvm.return
}
+
+// -----
+// CHECK-LABEL: @nvvm_redux_sync
+llvm.func @nvvm_redux_sync(%value: i32, %offset: i32) {
+ // CHECK: call i32 @llvm.nvvm.redux.sync.add(i32 %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.redux.sync add %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.umax(i32 %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.redux.sync umax %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.umin(i32 %{{.*}}, i32 %{{.*}})
+ %2 = nvvm.redux.sync umin %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.and(i32 %{{.*}}, i32 %{{.*}})
+ %3 = nvvm.redux.sync and %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.or(i32 %{{.*}}, i32 %{{.*}})
+ %4 = nvvm.redux.sync or %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.xor(i32 %{{.*}}, i32 %{{.*}})
+ %5 = nvvm.redux.sync xor %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.max(i32 %{{.*}}, i32 %{{.*}})
+ %6 = nvvm.redux.sync max %value, %offset: i32 -> i32
+ // CHECK: call i32 @llvm.nvvm.redux.sync.min(i32 %{{.*}}, i32 %{{.*}})
+ %7 = nvvm.redux.sync min %value, %offset: i32 -> i32
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_redux_sync_f32
+llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin(float %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs(float %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin.NaN(float %{{.*}}, i32 %{{.*}})
+ %2 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %{{.*}}, i32 %{{.*}})
+ %3 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax(float %{{.*}}, i32 %{{.*}})
+ %4 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs(float %{{.*}}, i32 %{{.*}})
+ %5 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax.NaN(float %{{.*}}, i32 %{{.*}})
+ %6 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
+ // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %{{.*}}, i32 %{{.*}})
+ %7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
+ llvm.return
+}
More information about the Mlir-commits
mailing list