[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f32 in redux.sync Op (PR #128137)

Srinivasa Ravi llvmlistbot at llvm.org
Fri Feb 21 02:48:52 PST 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/128137

>From 9f3bf779e46d32222c43cee4f9830fee3ebb1c8b Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 14 Feb 2025 23:24:16 +0530
Subject: [PATCH] [MLIR][NVVM] Add support for f32 in redux.sync Op

This change adds support for the f32 variants of the `redux.sync`
instruction (`fmin`/`fmax`, with optional `abs` and `nan` modifiers) to the
NVVM dialect, lowering them to the newly added LLVM intrinsics.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 23 ++++++++--
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 16 ++++++-
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 19 ++++++++
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 43 +++++++++++++++++++
 4 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..df43ed036d3a5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -257,11 +257,13 @@ def ReduxKindOr   : I32EnumAttrCase<"OR", 5, "or">;
 def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
 def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
 def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">; 
+def ReduxKindFmin : I32EnumAttrCase<"FMIN", 9, "fmin">; // f32 minimum
+def ReduxKindFmax : I32EnumAttrCase<"FMAX", 10, "fmax">; // f32 maximum
 
 /// Enum attribute of the different kinds.
 def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
   [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr, 
-    ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+   ReduxKindUmax, ReduxKindUmin, ReduxKindXor, ReduxKindFmin, ReduxKindFmax]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -273,9 +275,26 @@ def NVVM_ReduxOp :
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_Type:$val,
                  ReduxKindAttr:$kind,
-                 I32:$mask_and_clamp)> {
+                 I32:$mask_and_clamp,
+                 DefaultValuedAttr<BoolAttr, "false">:$abs,
+                 DefaultValuedAttr<BoolAttr, "false">:$nan)> {
+  let summary = "Redux Sync Op";
+  let description = [{
+    `redux.sync` performs the reduction operation `kind` on the 32-bit source
+    value `val` across all non-exited threads in the membermask.
+
+    The `fmin` and `fmax` kinds operate on f32 values; all other kinds operate
+    on i32 values.
+
+    The `abs` and `nan` attributes apply only to `fmin` and `fmax` and are
+    ignored by all other kinds: `abs` makes the reduction use the absolute
+    value of each input, and `nan` makes the result NaN if the input of any
+    participating thread is NaN.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync)
+  }];
   string llvmBuilder = [{
-      auto intId = getReduxIntrinsicId($_resultType, $kind);
+      auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
       $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 8b13735774663..6d34cf71bb780 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,9 +25,36 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
 
+/// Returns the f32 `redux.sync` intrinsic for `fmin` (`isMax == false`) or
+/// `fmax` (`isMax == true`), picking the `.abs` and/or `.NaN` variant as
+/// requested by the op's `abs` and `nan` attributes.
+static llvm::Intrinsic::ID getReduxF32IntrinsicId(bool isMax, bool hasAbs,
+                                                  bool hasNaN) {
+  if (isMax) {
+    if (hasAbs)
+      return hasNaN ? llvm::Intrinsic::nvvm_redux_sync_fmax_abs_NaN
+                    : llvm::Intrinsic::nvvm_redux_sync_fmax_abs;
+    return hasNaN ? llvm::Intrinsic::nvvm_redux_sync_fmax_NaN
+                  : llvm::Intrinsic::nvvm_redux_sync_fmax;
+  }
+  if (hasAbs)
+    return hasNaN ? llvm::Intrinsic::nvvm_redux_sync_fmin_abs_NaN
+                  : llvm::Intrinsic::nvvm_redux_sync_fmin_abs;
+  return hasNaN ? llvm::Intrinsic::nvvm_redux_sync_fmin_NaN
+                : llvm::Intrinsic::nvvm_redux_sync_fmin;
+}
+
+/// Maps a `redux.sync` kind to its intrinsic. `fmin`/`fmax` require an f32
+/// result and honor `hasAbs`/`hasNaN`; every other kind requires i32 and
+/// ignores both flags.
 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
-                                               NVVM::ReduxKind kind) {
-  if (!resultType->isIntegerTy(32))
+                                               NVVM::ReduxKind kind,
+                                               bool hasAbs, bool hasNaN) {
+  // A kind/type mismatch would build a call whose operand types do not match
+  // the selected intrinsic's signature, so the type is checked per kind.
+  const bool isF32Kind =
+      kind == NVVM::ReduxKind::FMIN || kind == NVVM::ReduxKind::FMAX;
+  if (isF32Kind ? !resultType->isFloatTy() : !resultType->isIntegerTy(32))
     llvm_unreachable("unsupported data type for redux");
 
   switch (kind) {
@@ -47,6 +74,10 @@ static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
     return llvm::Intrinsic::nvvm_redux_sync_max;
   case NVVM::ReduxKind::MIN:
     return llvm::Intrinsic::nvvm_redux_sync_min;
+  case NVVM::ReduxKind::FMIN:
+    return getReduxF32IntrinsicId(/*isMax=*/false, hasAbs, hasNaN);
+  case NVVM::ReduxKind::FMAX:
+    return getReduxF32IntrinsicId(/*isMax=*/true, hasAbs, hasNaN);
   }
   llvm_unreachable("unknown redux kind");
 }
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index dd54acd1e317e..85998d4e66254 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -411,6 +411,25 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
   llvm.return %r1 : i32
 }
 
+llvm.func @redux_sync_f32(%value: f32, %offset: i32) -> f32 {
+  // CHECK: nvvm.redux.sync fmin %{{.*}}, %{{.*}} : f32 -> f32
+  %r1 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+  // CHECK: nvvm.redux.sync fmin %{{.*}}, %{{.*}} {abs = true} : f32 -> f32
+  %r2 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmin %{{.*}}, %{{.*}} {nan = true} : f32 -> f32
+  %r3 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmin %{{.*}}, %{{.*}} {abs = true, nan = true} : f32 -> f32
+  %r4 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}, %{{.*}} : f32 -> f32
+  %r5 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}, %{{.*}} {abs = true} : f32 -> f32
+  %r6 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}, %{{.*}} {nan = true} : f32 -> f32
+  %r7 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
+  // CHECK: nvvm.redux.sync fmax %{{.*}}, %{{.*}} {abs = true, nan = true} : f32 -> f32
+  %r8 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
+  llvm.return %r1 : f32
+}
 
 // -----
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 5ab593452ab66..d11558698d860 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -780,3 +780,46 @@ llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_redux_sync(
+llvm.func @nvvm_redux_sync(%value: i32, %offset: i32) {
+  // CHECK: call i32 @llvm.nvvm.redux.sync.add(i32 %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.redux.sync add %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.umax(i32 %{{.*}}, i32 %{{.*}})
+  %1 = nvvm.redux.sync umax %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.umin(i32 %{{.*}}, i32 %{{.*}})
+  %2 = nvvm.redux.sync umin %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.and(i32 %{{.*}}, i32 %{{.*}})
+  %3 = nvvm.redux.sync and %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.or(i32 %{{.*}}, i32 %{{.*}})
+  %4 = nvvm.redux.sync or %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.xor(i32 %{{.*}}, i32 %{{.*}})
+  %5 = nvvm.redux.sync xor %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.max(i32 %{{.*}}, i32 %{{.*}})
+  %6 = nvvm.redux.sync max %value, %offset: i32 -> i32
+  // CHECK: call i32 @llvm.nvvm.redux.sync.min(i32 %{{.*}}, i32 %{{.*}})
+  %7 = nvvm.redux.sync min %value, %offset: i32 -> i32
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_redux_sync_f32(
+llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin(float %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs(float %{{.*}}, i32 %{{.*}})
+  %1 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin.NaN(float %{{.*}}, i32 %{{.*}})
+  %2 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %{{.*}}, i32 %{{.*}})
+  %3 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax(float %{{.*}}, i32 %{{.*}})
+  %4 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs(float %{{.*}}, i32 %{{.*}})
+  %5 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax.NaN(float %{{.*}}, i32 %{{.*}})
+  %6 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
+  // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %{{.*}}, i32 %{{.*}})
+  %7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
+  llvm.return
+}



More information about the Mlir-commits mailing list