[Mlir-commits] [mlir] 362240e - [mlir][Math] Support fold PowFOp with constant dense.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 6 19:16:32 PDT 2022
Author: jacquesguan
Date: 2022-07-07T10:13:08+08:00
New Revision: 362240e09e9e203b65d14b0b620803e7caa26536
URL: https://github.com/llvm/llvm-project/commit/362240e09e9e203b65d14b0b620803e7caa26536
DIFF: https://github.com/llvm/llvm-project/commit/362240e09e9e203b65d14b0b620803e7caa26536.diff
LOG: [mlir][Math] Support folding PowFOp with constant dense operands.
This patch adds a conditional binary constant folder, which lets a fold bail out when the constant operands do not meet the fold condition, and uses it for PowFOp so that PowFOp can also fold constant dense (elements) attributes.
Differential Revision: https://reviews.llvm.org/D129108
Added:
Modified:
mlir/include/mlir/Dialect/CommonFolders.h
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/test/Dialect/Math/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index d503bb02403ae..55dc5ec2349ce 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -23,12 +23,12 @@
namespace mlir {
/// Performs constant folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
-template <class AttrElementT,
- class ElementValueT = typename AttrElementT::ValueType,
- class CalculationT =
- function_ref<ElementValueT(ElementValueT, ElementValueT)>>
-Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
- const CalculationT &calculate) {
+template <
+ class AttrElementT, class ElementValueT = typename AttrElementT::ValueType,
+ class CalculationT =
+ function_ref<Optional<ElementValueT>(ElementValueT, ElementValueT)>>
+Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
+ const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
if (!operands[0] || !operands[1])
return {};
@@ -39,9 +39,14 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
auto lhs = operands[0].cast<AttrElementT>();
auto rhs = operands[1].cast<AttrElementT>();
- return AttrElementT::get(lhs.getType(),
- calculate(lhs.getValue(), rhs.getValue()));
+ auto calRes = calculate(lhs.getValue(), rhs.getValue());
+
+ if (!calRes)
+ return {};
+
+ return AttrElementT::get(lhs.getType(), *calRes);
}
+
if (operands[0].isa<SplatElementsAttr>() &&
operands[1].isa<SplatElementsAttr>()) {
// Both operands are splats so we can avoid expanding the values out and
@@ -51,7 +56,10 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
rhs.getSplatValue<ElementValueT>());
- return DenseElementsAttr::get(lhs.getType(), elementResult);
+ if (!elementResult)
+ return {};
+
+ return DenseElementsAttr::get(lhs.getType(), *elementResult);
} else if (operands[0].isa<ElementsAttr>() &&
operands[1].isa<ElementsAttr>()) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
@@ -63,13 +71,31 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
auto rhsIt = rhs.value_begin<ElementValueT>();
SmallVector<ElementValueT, 4> elementResults;
elementResults.reserve(lhs.getNumElements());
- for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
- elementResults.push_back(calculate(*lhsIt, *rhsIt));
+ for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
+ auto elementResult = calculate(*lhsIt, *rhsIt);
+ if (!elementResult)
+ return {};
+ elementResults.push_back(*elementResult);
+ }
+
return DenseElementsAttr::get(lhs.getType(), elementResults);
}
return {};
}
+template <class AttrElementT,
+ class ElementValueT = typename AttrElementT::ValueType,
+ class CalculationT =
+ function_ref<ElementValueT(ElementValueT, ElementValueT)>>
+Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
+ const CalculationT &calculate) {
+ return constFoldBinaryOpConditional<AttrElementT>(
+ operands,
+ [&](ElementValueT a, ElementValueT b) -> Optional<ElementValueT> {
+ return calculate(a, b);
+ });
+}
+
/// Performs constant folding `calculate` with element-wise behavior on the one
/// attributes in `operands` and returns the result if possible.
template <class AttrElementT,
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index f14bb003deb5d..34e20724c78a0 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -121,32 +121,18 @@ OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
- auto ft = getType().dyn_cast<FloatType>();
- if (!ft)
- return {};
-
- APFloat vals[2]{APFloat(ft.getFloatSemantics()),
- APFloat(ft.getFloatSemantics())};
- for (int i = 0; i < 2; ++i) {
- if (!operands[i])
- return {};
-
- auto attr = operands[i].dyn_cast<FloatAttr>();
- if (!attr)
- return {};
-
- vals[i] = attr.getValue();
- }
-
- if (ft.getWidth() == 64)
- return FloatAttr::get(
- getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble()));
-
- if (ft.getWidth() == 32)
- return FloatAttr::get(
- getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat()));
-
- return {};
+ return constFoldBinaryOpConditional<FloatAttr>(
+ operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
+ if (a.getSizeInBits(a.getSemantics()) == 64 &&
+ b.getSizeInBits(b.getSemantics()) == 64)
+ return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
+
+ if (a.getSizeInBits(a.getSemantics()) == 32 &&
+ b.getSizeInBits(b.getSemantics()) == 32)
+ return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
+
+ return {};
+ });
}
OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index e29941f966aab..bcfdf1b9e965c 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -83,6 +83,16 @@ func.func @powf_fold() -> f32 {
return %r : f32
}
+// CHECK-LABEL: @powf_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 4.000000e+00, 9.000000e+00, 1.600000e+01]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @powf_fold_vec() -> (vector<4xf32>) {
+ %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+ %v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32>
+ %0 = math.powf %v1, %v2 : vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
// CHECK-LABEL: @sqrt_fold
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
// CHECK: return %[[cst]]
More information about the Mlir-commits mailing list