[Mlir-commits] [mlir] [mlir] Add `maxnumf` and `minnumf` to `AtomicRMWKind` (PR #66442)

Daniil Dudkin llvmlistbot at llvm.org
Thu Sep 14 15:31:57 PDT 2023


https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66442:

This commit adds the `maxnumf` and `minnumf` kinds to `AtomicRMWKind`,
along with code generation for them.


>From 3d680710f6ea0680895024337820a367d4f38fe8 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Fri, 15 Sep 2023 01:00:22 +0300
Subject: [PATCH] [mlir] Add `maxnumf` and `minnumf` to `AtomicRMWKind`

This commit adds the `maxnumf` and `minnumf` kinds to `AtomicRMWKind`,
along with code generation for them.
---
 .../include/mlir/Dialect/Arith/IR/ArithBase.td |  4 +++-
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp         |  4 ++++
 .../Dialect/MemRef/Transforms/ExpandOps.cpp    |  7 +++++--
 mlir/test/Dialect/MemRef/expand-ops.mlir       | 18 +++++++++++++++---
 4 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index a833e9c8220af5b..133af893e4efa74 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF     : I64EnumAttrCase<"mulf", 9>;
 def ATOMIC_RMW_KIND_MULI     : I64EnumAttrCase<"muli", 10>;
 def ATOMIC_RMW_KIND_ORI      : I64EnumAttrCase<"ori", 11>;
 def ATOMIC_RMW_KIND_ANDI     : I64EnumAttrCase<"andi", 12>;
+def ATOMIC_RMW_KIND_MAXNUMF  : I64EnumAttrCase<"maxnumf", 13>;
+def ATOMIC_RMW_KIND_MINNUMF  : I64EnumAttrCase<"minnumf", 14>;
 
 def AtomicRMWKindAttr : I64EnumAttr<
     "AtomicRMWKind", "",
@@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr<
      ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
      ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
      ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
-     ATOMIC_RMW_KIND_ANDI]> {
+     ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
   let cppNamespace = "::mlir::arith";
 }
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d39c5b6051122e4..ae8a6ef350ce191 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minimumf:
     return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
+   case AtomicRMWKind::maxnumf:
+    return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
+  case AtomicRMWKind::minnumf:
+    return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::maxs:
     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
   case AtomicRMWKind::mins:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index b3beaada2539dbc..faba12f5bf82f89 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace memref {
@@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
     target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
         [](memref::AtomicRMWOp op) {
-          return op.getKind() != arith::AtomicRMWKind::maximumf &&
-                 op.getKind() != arith::AtomicRMWKind::minimumf;
+          constexpr std::array shouldBeExpandedKinds = {
+              arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
+              arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
+          return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
         });
     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
       return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 6c98cf978505334..f958a92b751a4ab 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -3,9 +3,11 @@
 // CHECK-LABEL: func @atomic_rmw_to_generic
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
-  %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
-  %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
-  return %x : f32
+  %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  return %a : f32
 }
 // CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
@@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32
 // CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
 // CHECK:   memref.atomic_yield [[MINIMUM]] : f32
 // CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MAXNUM]] : f32
+// CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MINNUM]] : f32
+// CHECK: }
 // CHECK: return [[RESULT]] : f32
 
 // -----



More information about the Mlir-commits mailing list