[Mlir-commits] [mlir] [mlir][arith] Rename `AtomicRMWKind`'s `maxf` → `maximumf`, `minf` → `minimumf` (PR #66135)

Daniil Dudkin llvmlistbot at llvm.org
Tue Sep 12 13:03:03 PDT 2023


https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66135:

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit renames `maxf` and `minf` enumerators of `AtomicRMWKind`
to better reflect the current naming scheme and the goals of the RFC.


From 536eeb936923a35023107b8a894f5320be5b137e Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Tue, 12 Sep 2023 22:58:32 +0300
Subject: [PATCH] [mlir][arith] Rename `AtomicRMWKind`'s `maxf` → `maximumf`, `minf` → `minimumf`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

This commit renames `maxf` and `minf` enumerators of `AtomicRMWKind`
to better reflect the current naming scheme and the goals of the RFC.
---
 .../mlir/Dialect/Arith/IR/ArithBase.td        | 30 +++++++++----------
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  4 +--
 .../Affine/Analysis/AffineAnalysis.cpp        |  6 ++--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 12 ++++----
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  6 ++--
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  4 +--
 .../Dialect/MemRef/Transforms/ExpandOps.cpp   | 10 +++----
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  4 +--
 mlir/test/Dialect/Affine/invalid.mlir         |  2 +-
 mlir/test/Dialect/Affine/ops.mlir             |  4 +--
 mlir/test/Dialect/MemRef/expand-ops.mlir      |  2 +-
 11 files changed, 43 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 78fd7bdf012f8a8..a833e9c8220af5b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -69,25 +69,25 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
   let cppNamespace = "::mlir::arith";
 }
 
-def ATOMIC_RMW_KIND_ADDF    : I64EnumAttrCase<"addf", 0>;
-def ATOMIC_RMW_KIND_ADDI    : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN  : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXF    : I64EnumAttrCase<"maxf", 3>;
-def ATOMIC_RMW_KIND_MAXS    : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU    : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINF    : I64EnumAttrCase<"minf", 6>;
-def ATOMIC_RMW_KIND_MINS    : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU    : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF    : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI    : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI     : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI    : I64EnumAttrCase<"andi", 12>;
+def ATOMIC_RMW_KIND_ADDF     : I64EnumAttrCase<"addf", 0>;
+def ATOMIC_RMW_KIND_ADDI     : I64EnumAttrCase<"addi", 1>;
+def ATOMIC_RMW_KIND_ASSIGN   : I64EnumAttrCase<"assign", 2>;
+def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
+def ATOMIC_RMW_KIND_MAXS     : I64EnumAttrCase<"maxs", 4>;
+def ATOMIC_RMW_KIND_MAXU     : I64EnumAttrCase<"maxu", 5>;
+def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
+def ATOMIC_RMW_KIND_MINS     : I64EnumAttrCase<"mins", 7>;
+def ATOMIC_RMW_KIND_MINU     : I64EnumAttrCase<"minu", 8>;
+def ATOMIC_RMW_KIND_MULF     : I64EnumAttrCase<"mulf", 9>;
+def ATOMIC_RMW_KIND_MULI     : I64EnumAttrCase<"muli", 10>;
+def ATOMIC_RMW_KIND_ORI      : I64EnumAttrCase<"ori", 11>;
+def ATOMIC_RMW_KIND_ANDI     : I64EnumAttrCase<"andi", 12>;
 
 def AtomicRMWKindAttr : I64EnumAttr<
     "AtomicRMWKind", "",
     [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
-     ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
-     ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+     ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
+     ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
      ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
      ATOMIC_RMW_KIND_ANDI]> {
   let cppNamespace = "::mlir::arith";
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 516a6b8ed88e61e..97faefe2cd4d631 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1594,13 +1594,13 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
     return LLVM::AtomicBinOp::add;
   case arith::AtomicRMWKind::assign:
     return LLVM::AtomicBinOp::xchg;
-  case arith::AtomicRMWKind::maxf:
+  case arith::AtomicRMWKind::maximumf:
     return LLVM::AtomicBinOp::fmax;
   case arith::AtomicRMWKind::maxs:
     return LLVM::AtomicBinOp::max;
   case arith::AtomicRMWKind::maxu:
     return LLVM::AtomicBinOp::umax;
-  case arith::AtomicRMWKind::minf:
+  case arith::AtomicRMWKind::minimumf:
     return LLVM::AtomicBinOp::fmin;
   case arith::AtomicRMWKind::mins:
     return LLVM::AtomicBinOp::min;
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index ab1dfbdb419b891..1ba0bc8b6bfbe5e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -60,8 +60,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
           .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
           .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
           .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
-          .Case([](arith::MinimumFOp) { return arith::AtomicRMWKind::minf; })
-          .Case([](arith::MaximumFOp) { return arith::AtomicRMWKind::maxf; })
+          .Case(
+              [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
+          .Case(
+              [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
           .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
           .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
           .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 77bf8a438d6db84..1e34ac598860f52 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2369,7 +2369,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                             OpBuilder &builder, Location loc,
                                             bool useOnlyFiniteValue) {
   switch (kind) {
-  case AtomicRMWKind::maxf: {
+  case AtomicRMWKind::maximumf: {
     const llvm::fltSemantics &semantic =
         llvm::cast<FloatType>(resultType).getFloatSemantics();
     APFloat identity = useOnlyFiniteValue
@@ -2390,7 +2390,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
     return builder.getIntegerAttr(
         resultType, APInt::getSignedMinValue(
                         llvm::cast<IntegerType>(resultType).getWidth()));
-  case AtomicRMWKind::minf: {
+  case AtomicRMWKind::minimumf: {
     const llvm::fltSemantics &semantic =
         llvm::cast<FloatType>(resultType).getFloatSemantics();
     APFloat identity = useOnlyFiniteValue
@@ -2426,8 +2426,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
           // Floating-point operations.
           .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
           .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
-          .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maxf; })
-          .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minf; })
+          .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
+          .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
           // Integer operations.
           .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
           .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
@@ -2482,9 +2482,9 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return builder.create<arith::MulFOp>(loc, lhs, rhs);
   case AtomicRMWKind::muli:
     return builder.create<arith::MulIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::maxf:
+  case AtomicRMWKind::maximumf:
     return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
-  case AtomicRMWKind::minf:
+  case AtomicRMWKind::minimumf:
     return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::maxs:
     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d26e68cb47ac1e0..f87aa4559e10afe 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2549,9 +2549,9 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   dims.erase(dims.begin() + reductionDim);
   // Step 1: Compute max along dim.
   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF =
-      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc,
-                              /*useOnlyFiniteValue=*/true);
+  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
+                                                 elementType, b, loc,
+                                                 /*useOnlyFiniteValue=*/true);
   Value neutralForMaxFInit =
       b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
           .result();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 42da47a5381e789..215a8f5e7d18be0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3402,8 +3402,8 @@ LogicalResult AtomicRMWOp::verify() {
         "expects the number of subscripts to be equal to memref rank");
   switch (getKind()) {
   case arith::AtomicRMWKind::addf:
-  case arith::AtomicRMWKind::maxf:
-  case arith::AtomicRMWKind::minf:
+  case arith::AtomicRMWKind::maximumf:
+  case arith::AtomicRMWKind::minimumf:
   case arith::AtomicRMWKind::mulf:
     if (!llvm::isa<FloatType>(getValue().getType()))
       return emitOpError() << "with kind '"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 8a276ebbff6a921..7c3ca19b789c750 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -36,7 +36,7 @@ namespace {
 /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
 /// `memref.generic_atomic_rmw` with the expanded code.
 ///
-/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
 ///
 /// will be lowered to
 ///
@@ -54,10 +54,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
                                 PatternRewriter &rewriter) const final {
     arith::CmpFPredicate predicate;
     switch (op.getKind()) {
-    case arith::AtomicRMWKind::maxf:
+    case arith::AtomicRMWKind::maximumf:
       predicate = arith::CmpFPredicate::OGT;
       break;
-    case arith::AtomicRMWKind::minf:
+    case arith::AtomicRMWKind::minimumf:
       predicate = arith::CmpFPredicate::OLT;
       break;
     default:
@@ -137,8 +137,8 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
     target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
         [](memref::AtomicRMWOp op) {
-          return op.getKind() != arith::AtomicRMWKind::maxf &&
-                 op.getKind() != arith::AtomicRMWKind::minf;
+          return op.getKind() != arith::AtomicRMWKind::maximumf &&
+                 op.getKind() != arith::AtomicRMWKind::minimumf;
         });
     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
       return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9422936bf21e357..11aa76798bcaae1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -493,7 +493,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   case arith::AtomicRMWKind::muli:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::MUL, vector);
-  case arith::AtomicRMWKind::minf:
+  case arith::AtomicRMWKind::minimumf:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::MINF, vector);
   case arith::AtomicRMWKind::mins:
@@ -502,7 +502,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   case arith::AtomicRMWKind::minu:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::MINUI, vector);
-  case arith::AtomicRMWKind::maxf:
+  case arith::AtomicRMWKind::maximumf:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::MAXF, vector);
   case arith::AtomicRMWKind::maxs:
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index fd7d7df9d8735f5..1dc3451ed7db87c 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -287,7 +287,7 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = memref.alloc() : memref<100x100xi32>
-  %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minf") -> (f32) {
+  %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minimumf") -> (f32) {
     %2 = affine.load %0[%i, %j] : memref<100x100xi32>
     //  expected-error@+1 {{types mismatch between yield op and its parent}}
     affine.yield %2 : i32
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index f55d59a3e64707b..1063f2a7ecba489 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -158,8 +158,8 @@ func.func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {
 func.func @parallel(%A : memref<100x100xf32>, %N : index) {
   // CHECK: affine.parallel (%[[I0:.*]], %[[J0:.*]]) = (0, 0) to (symbol(%[[N]]), 100) step (10, 10)
   affine.parallel (%i0, %j0) = (0, 0) to (symbol(%N), 100) step (10, 10) {
-    // CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minf", "maxf") -> (f32, f32)
-    %0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minf", "maxf") -> (f32, f32) {
+    // CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minimumf", "maximumf") -> (f32, f32)
+    %0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minimumf", "maximumf") -> (f32, f32) {
       %2 = affine.load %A[%i0 + %i0, %j0 + %j1] : memref<100x100xf32>
       affine.yield %2, %2 : f32, f32
     }
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index a0d8e52d6e7e275..3234b35e99dcdfe 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: func @atomic_rmw_to_generic
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
-  %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
   return %x : f32
 }
 // CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {



More information about the Mlir-commits mailing list