[Mlir-commits] [mlir] 088d388 - [mlir][Arithmetic] Add constant folder for negf.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 8 00:36:52 PDT 2022
Author: jacquesguan
Date: 2022-04-08T07:36:29Z
New Revision: 088d38890ccee92d5ef6ae13ec1c50f9b0083866
URL: https://github.com/llvm/llvm-project/commit/088d38890ccee92d5ef6ae13ec1c50f9b0083866
DIFF: https://github.com/llvm/llvm-project/commit/088d38890ccee92d5ef6ae13ec1c50f9b0083866.diff
LOG: [mlir][Arithmetic] Add constant folder for negf.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D123293
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/Dialect/CommonFolders.h
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index f7fd294be31ff..6305c6947c753 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -585,6 +585,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
%x = arith.negf %y : tensor<4x?xf8>
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index b3f600ed4ff49..7ba43c92e7563 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -69,6 +69,45 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
}
return {};
}
+
+/// Performs constant folding `calculate` with element-wise behavior on the one
+/// attributes in `operands` and returns the result if possible.
+/// Performs constant folding `calculate` with element-wise behavior on the
+/// single attribute in `operands` and returns the result if possible.
+///
+/// `operands` must contain exactly one entry. A null entry produces a null
+/// Attribute (nothing is folded). Handled operand kinds:
+///   - `AttrElementT` (scalar): `calculate` is applied to its value.
+///   - `SplatElementsAttr`: `calculate` is applied once to the splat value and
+///     a splat DenseElementsAttr of the same type is returned.
+///   - any other `ElementsAttr`: `calculate` is applied to every element and a
+///     DenseElementsAttr of the same type is returned.
+/// Any other attribute kind yields a null Attribute.
+///
+/// `calculate` is taken by const lvalue reference (rather than `const &&`) so
+/// that both temporary lambdas and named callables can be passed.
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                           const CalculationT &calculate) {
+  assert(operands.size() == 1 && "unary op takes one operands");
+  if (!operands[0])
+    return {};
+
+  if (auto scalar = operands[0].dyn_cast<AttrElementT>())
+    return AttrElementT::get(scalar.getType(), calculate(scalar.getValue()));
+
+  if (auto splat = operands[0].dyn_cast<SplatElementsAttr>()) {
+    // The operand is a splat, so avoid expanding the values out and just fold
+    // based on the single splat value.
+    auto elementResult = calculate(splat.getSplatValue<ElementValueT>());
+    return DenseElementsAttr::get(splat.getType(), elementResult);
+  }
+
+  if (auto elements = operands[0].dyn_cast<ElementsAttr>()) {
+    // The operand is ElementsAttr-derived; perform an element-wise fold by
+    // expanding the values.
+    auto valueIt = elements.value_begin<ElementValueT>();
+    SmallVector<ElementValueT> elementResults;
+    elementResults.reserve(elements.getNumElements());
+    for (size_t i = 0, e = elements.getNumElements(); i < e; ++i, ++valueIt)
+      elementResults.push_back(calculate(*valueIt));
+    return DenseElementsAttr::get(elements.getType(), elementResults);
+  }
+
+  return {};
+}
+
} // namespace mlir
#endif // MLIR_DIALECT_COMMONFOLDERS_H
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 74b77dc9cbd63..d38f4f1c9994e 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -573,6 +573,15 @@ void arith::XOrIOp::getCanonicalizationPatterns(
patterns.add<XOrINotCmpI>(context);
}
+//===----------------------------------------------------------------------===//
+// NegFOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `arith.negf` when its operand is a constant float attribute (scalar,
+/// splat, or element-wise) by replacing every value with its negation.
+/// Returns a null result when the operand is not a foldable constant.
+OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
+  // The lambda is passed as a temporary so it binds to the helper's parameter.
+  return constFoldUnaryOp<FloatAttr>(
+      operands, [](const APFloat &value) -> APFloat { return -value; });
+}
+
//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 550e84c97118d..b4c92d6089e9b 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1166,3 +1166,14 @@ func @nofoldShrs2() -> i64 {
%r = arith.shrsi %c1, %cm32 : i64
return %r : i64
}
+
+// -----
+
+// CHECK-LABEL: @test_negf(
+// CHECK: %[[res:.+]] = arith.constant -2.0
+// CHECK: return %[[res]]
+func @test_negf() -> (f32) {
+  %c = arith.constant 2.0 : f32
+  %0 = arith.negf %c : f32
+  return %0: f32
+}
+
+// -----
+
+// Also exercise the splat-elements path of the negf folder.
+// CHECK-LABEL: @test_negf_splat(
+// CHECK: %[[res:.+]] = arith.constant dense<-2.0
+// CHECK-SAME: vector<4xf32>
+// CHECK: return %[[res]]
+func @test_negf_splat() -> (vector<4xf32>) {
+  %c = arith.constant dense<2.0> : vector<4xf32>
+  %0 = arith.negf %c : vector<4xf32>
+  return %0 : vector<4xf32>
+}
More information about the Mlir-commits
mailing list