[Mlir-commits] [mlir] f8e59b0 - [mlir][arith] Move getNeutralElement from Linalg utils to arith
Quentin Colombet
llvmlistbot at llvm.org
Tue Jul 4 05:01:20 PDT 2023
Author: Quentin Colombet
Date: 2023-07-04T13:59:48+02:00
New Revision: f8e59b09f42ccc4d386b0563fc89807e4d5b35a2
URL: https://github.com/llvm/llvm-project/commit/f8e59b09f42ccc4d386b0563fc89807e4d5b35a2
DIFF: https://github.com/llvm/llvm-project/commit/f8e59b09f42ccc4d386b0563fc89807e4d5b35a2.diff
LOG: [mlir][arith] Move getNeutralElement from Linalg utils to arith
This consolidates where this kind of implementation lives and
refactors the code to share more of it.
NFC
Differential Revision: https://reviews.llvm.org/D154362
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/Arith.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index e93af36d871a4d..0abafa1d4c834b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -125,6 +125,10 @@ bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc);
+/// Return the identity numeric value associated to the given op. Emit an
+/// error on `op` and return std::nullopt if no neutral element is known.
+std::optional<TypedAttr> getNeutralElement(Operation *op);
+
/// Returns the identity value associated with an AtomicRMWKind op.
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
Location loc);
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index fb82e6b5ab71e8..b3397ae131b56f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -95,10 +95,6 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
-/// Return the identity numeric value associated to the give op. Return
-/// std::nullopt if there is no known neutral element.
-std::optional<TypedAttr> getNeutralElement(Operation *op);
-
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5e95d16b87a635..219804b005027b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -25,6 +25,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arith;
@@ -2377,6 +2378,38 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
return nullptr;
}
+/// Return the given op's neutral element, or emit an error and return nullopt.
+std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
+ std::optional<AtomicRMWKind> maybeKind =
+ llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
+ // Floating-point operations.
+ .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
+ .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
+ .Case([](arith::MaxFOp op) { return AtomicRMWKind::maxf; })
+ .Case([](arith::MinFOp op) { return AtomicRMWKind::minf; })
+ // Integer operations. XOrIOp maps to `ori`: both have identity 0.
+ .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
+ .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
+ .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
+ .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
+ .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
+ .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
+ .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
+ .Default([](Operation *op) { return std::nullopt; });
+ if (!maybeKind) {
+ op->emitError() << "Unknown neutral element for: " << *op;
+ return std::nullopt;
+ }
+
+ // Throwaway builder: only needed as a helper for attribute creation.
+ OpBuilder b(op->getContext());
+ Type resultType = op->getResult(0).getType();
+
+ return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc());
+}
+
/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
OpBuilder &builder, Location loc) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 982b0243e953a7..6c859b6cb70eb5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -66,7 +66,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
Operation *reductionOp = combinerOps[0];
- std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
+ std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
if (!identity.has_value())
return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
@@ -274,7 +274,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
SmallVector<TypedAttr> neutralElements;
for (Operation *reductionOp : combinerOps) {
- std::optional<TypedAttr> neutralElement = getNeutralElement(reductionOp);
+ std::optional<TypedAttr> neutralElement =
+ arith::getNeutralElement(reductionOp);
if (!neutralElement.has_value())
return b.notifyMatchFailure(op, "cannot find neutral element.");
neutralElements.push_back(*neutralElement);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 54f0bd249c3cfa..a1ab4d9fd03801 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -271,7 +271,7 @@ struct LinalgOpPartialReductionInterface
return op->emitOpError("Failed to anaysis the reduction operation.");
Operation *reductionOp = combinerOps[0];
- std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
+ std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
if (!identity.has_value())
return op->emitOpError(
"Failed to get an identity value for the reduction operation.");
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 55da7096b25c2b..fe324ab406657a 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -905,37 +905,5 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
return reassociation;
}
-/// Return the identity numeric value associated to the give op.
-std::optional<TypedAttr> getNeutralElement(Operation *op) {
- // Builder only used as helper for attribute creation.
- OpBuilder b(op->getContext());
- Type resultType = op->getResult(0).getType();
- if (auto floatType = dyn_cast<FloatType>(resultType)) {
- const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
- if (isa<arith::AddFOp>(op))
- return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
- if (isa<arith::MulFOp>(op))
- return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
- if (isa<arith::MaxFOp>(op))
- return b.getFloatAttr(resultType,
- llvm::APFloat::getInf(semantic, /*Negative=*/true));
- if (isa<arith::MinFOp>(op))
- return b.getFloatAttr(
- resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
- return std::nullopt;
- }
- if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
- return b.getIntegerAttr(resultType, 0);
- if (isa<arith::AndIOp>(op))
- return b.getIntegerAttr(resultType, -1);
- if (isa<arith::MaxSIOp>(op))
- return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
- if (isa<arith::MinSIOp>(op))
- return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
- if (isa<arith::MulIOp>(op))
- return b.getIntegerAttr(resultType, 1);
- return std::nullopt;
-}
-
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index f281add1c28a3d..8c67b80c63f463 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -340,6 +340,7 @@ module {
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
%1 = llvm.fmul %in, %in : f32
+ // expected-error @below {{Unknown neutral element for:}}
%2 = llvm.fadd %1, %out : f32
linalg.yield %2 : f32
} -> tensor<?xf32>
More information about the Mlir-commits mailing list