[Mlir-commits] [mlir] f8e59b0 - [mlir][arith] Move getNeutralElement from Linalg utils to arith

Quentin Colombet llvmlistbot at llvm.org
Tue Jul 4 05:01:20 PDT 2023


Author: Quentin Colombet
Date: 2023-07-04T13:59:48+02:00
New Revision: f8e59b09f42ccc4d386b0563fc89807e4d5b35a2

URL: https://github.com/llvm/llvm-project/commit/f8e59b09f42ccc4d386b0563fc89807e4d5b35a2
DIFF: https://github.com/llvm/llvm-project/commit/f8e59b09f42ccc4d386b0563fc89807e4d5b35a2.diff

LOG: [mlir][arith] Move getNeutralElement from Linalg utils to arith

This consolidates where this kind of implementation lives and
refactors the code so that getNeutralElement reuses getIdentityValueAttr.

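NFC intended, with two visible differences: getNeutralElement now emits an
error when no neutral element is known, and it also recognizes arith.maxui
and arith.minui.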

Differential Revision: https://reviews.llvm.org/D154362

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/Arith.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index e93af36d871a4d..0abafa1d4c834b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -125,6 +125,10 @@ bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
 TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                OpBuilder &builder, Location loc);
 
+/// Return the neutral (identity) element of `op`. Return std::nullopt, after
+/// emitting an error on `op`, if there is no known neutral element.
+std::optional<TypedAttr> getNeutralElement(Operation *op);
+
 /// Returns the identity value associated with an AtomicRMWKind op.
 Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
                        Location loc);

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index fb82e6b5ab71e8..b3397ae131b56f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -95,10 +95,6 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
-/// Return the identity numeric value associated to the give op. Return
-/// std::nullopt if there is no known neutral element.
-std::optional<TypedAttr> getNeutralElement(Operation *op);
-
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5e95d16b87a635..219804b005027b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -25,6 +25,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::arith;
@@ -2377,6 +2378,38 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
   return nullptr;
 }
 
+/// Return the neutral element of op, or emit an error and return std::nullopt.
+std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
+  std::optional<AtomicRMWKind> maybeKind =
+      llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
+          // Floating-point operations.
+          .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
+          .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
+          .Case([](arith::MaxFOp op) { return AtomicRMWKind::maxf; })
+          .Case([](arith::MinFOp op) { return AtomicRMWKind::minf; })
+          // Integer operations. XOrIOp maps to ori: both have identity 0.
+          .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
+          .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
+          .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+          .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
+          .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
+          .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
+          .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
+          .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
+          .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
+          .Default([](Operation *op) { return std::nullopt; });
+  if (!maybeKind) {
+    op->emitError() << "Unknown neutral element for: " << *op;
+    return std::nullopt;
+  }
+
+  // Builder only used as helper for attribute creation.
+  OpBuilder b(op->getContext());
+  Type resultType = op->getResult(0).getType();
+
+  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc());
+}
+
 /// Returns the identity value associated with an AtomicRMWKind op.
 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                     OpBuilder &builder, Location loc) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 982b0243e953a7..6c859b6cb70eb5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -66,7 +66,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
 
   Operation *reductionOp = combinerOps[0];
-  std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
+  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
   if (!identity.has_value())
     return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
 
@@ -274,7 +274,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
 
   SmallVector<TypedAttr> neutralElements;
   for (Operation *reductionOp : combinerOps) {
-    std::optional<TypedAttr> neutralElement = getNeutralElement(reductionOp);
+    std::optional<TypedAttr> neutralElement =
+        arith::getNeutralElement(reductionOp);
     if (!neutralElement.has_value())
       return b.notifyMatchFailure(op, "cannot find neutral element.");
     neutralElements.push_back(*neutralElement);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 54f0bd249c3cfa..a1ab4d9fd03801 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -271,7 +271,7 @@ struct LinalgOpPartialReductionInterface
       return op->emitOpError("Failed to anaysis the reduction operation.");
 
     Operation *reductionOp = combinerOps[0];
-    std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
+    std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
     if (!identity.has_value())
       return op->emitOpError(
           "Failed to get an identity value for the reduction operation.");

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 55da7096b25c2b..fe324ab406657a 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -905,37 +905,5 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
   return reassociation;
 }
 
-/// Return the identity numeric value associated to the give op.
-std::optional<TypedAttr> getNeutralElement(Operation *op) {
-  // Builder only used as helper for attribute creation.
-  OpBuilder b(op->getContext());
-  Type resultType = op->getResult(0).getType();
-  if (auto floatType = dyn_cast<FloatType>(resultType)) {
-    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
-    if (isa<arith::AddFOp>(op))
-      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
-    if (isa<arith::MulFOp>(op))
-      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
-    if (isa<arith::MaxFOp>(op))
-      return b.getFloatAttr(resultType,
-                            llvm::APFloat::getInf(semantic, /*Negative=*/true));
-    if (isa<arith::MinFOp>(op))
-      return b.getFloatAttr(
-          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
-    return std::nullopt;
-  }
-  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
-    return b.getIntegerAttr(resultType, 0);
-  if (isa<arith::AndIOp>(op))
-    return b.getIntegerAttr(resultType, -1);
-  if (isa<arith::MaxSIOp>(op))
-    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
-  if (isa<arith::MinSIOp>(op))
-    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
-  if (isa<arith::MulIOp>(op))
-    return b.getIntegerAttr(resultType, 1);
-  return std::nullopt;
-}
-
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index f281add1c28a3d..8c67b80c63f463 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -340,6 +340,7 @@ module {
     %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
     ^bb0(%in: f32, %out: f32):
       %1 = llvm.fmul %in, %in  : f32
+      // expected-error @below {{Unknown neutral element for:}}
       %2 = llvm.fadd %1, %out  : f32
       linalg.yield %2 : f32
     } -> tensor<?xf32>


        


More information about the Mlir-commits mailing list