[flang-commits] [flang] bd1d170 - [flang] Prevent a bad folding rewrite
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Fri Sep 23 08:11:31 PDT 2022
Author: Peter Klausler
Date: 2022-09-23T08:11:16-07:00
New Revision: bd1d1701649c3d26e5b7ea788bb67f58f0779f47
URL: https://github.com/llvm/llvm-project/commit/bd1d1701649c3d26e5b7ea788bb67f58f0779f47
DIFF: https://github.com/llvm/llvm-project/commit/bd1d1701649c3d26e5b7ea788bb67f58f0779f47.diff
LOG: [flang] Prevent a bad folding rewrite
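When a subexpression does not have both constant elements and
a constant shape, folding was rewriting it into a rank-1 array
(a vector) of its elements. This is wrong when the shape shows
that the result has rank greater than 1. With this change the
rewrite happens only when the shape is constant and either the
elements are also constant (so the folded value is reshaped) or
the result is rank 1 with an extent matching the array
constructor; otherwise FromArrayConstructor and MapOperation
return std::nullopt instead of a rewritten expression.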
Differential Revision: https://reviews.llvm.org/D134392
Added:
Modified:
flang/lib/Evaluate/fold-implementation.h
Removed:
################################################################################
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 37f4f8db996f..4bcd41409029 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -1310,15 +1310,28 @@ AsFlatArrayConstructor(const Expr<SomeKind<CAT>> &expr) {
// into an Expr<T>, folds it, and returns the resulting wrapped
// array constructor or constant array value.
template <typename T>
-Expr<T> FromArrayConstructor(FoldingContext &context,
- ArrayConstructor<T> &&values, std::optional<ConstantSubscripts> &&shape) {
- Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
- if (shape) {
+std::optional<Expr<T>> FromArrayConstructor(
+ FoldingContext &context, ArrayConstructor<T> &&values, const Shape &shape) {
+ if (auto constShape{AsConstantExtents(context, shape)}) { // no rewrite unless every extent is a known constant
+ Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
if (auto *constant{UnwrapConstantValue<T>(result)}) {
- return Expr<T>{constant->Reshape(std::move(*shape))};
+ // Elements and shape are both constant: reshape the folded constant to the known extents.
+ return Expr<T>{constant->Reshape(std::move(*constShape))};
+ }
+ if (constShape->size() == 1) { // a flat array constructor is only a valid stand-in for a rank-1 result
+ if (auto elements{GetShape(context, result)}) {
+ if (auto constElements{AsConstantExtents(context, *elements)}) {
+ if (constElements->size() == 1 &&
+ constElements->at(0) == constShape->at(0)) {
+ // Elements are not constant, but the result is rank 1 and the folded
+ // array constructor's own constant extent matches it, so return it as is.
+ return std::move(result);
+ }
+ }
+ }
}
}
- return result;
+ return std::nullopt; // no safe rewrite: shape not constant, or non-constant elements without a matching rank-1 extent
}
// MapOperation is a utility for various specializations of ApplyElementwise()
@@ -1330,7 +1343,7 @@ Expr<T> FromArrayConstructor(FoldingContext &context,
// Unary case
template <typename RESULT, typename OPERAND>
-Expr<RESULT> MapOperation(FoldingContext &context,
+std::optional<Expr<RESULT>> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f, const Shape &shape,
Expr<OPERAND> &&values) {
ArrayConstructor<RESULT> result{values};
@@ -1352,8 +1365,7 @@ Expr<RESULT> MapOperation(FoldingContext &context,
result.Push(Fold(context, f(std::move(scalar))));
}
}
- return FromArrayConstructor(
- context, std::move(result), AsConstantExtents(context, shape));
+ return FromArrayConstructor(context, std::move(result), shape);
}
template <typename RESULT, typename A>
@@ -1369,10 +1381,11 @@ ArrayConstructor<RESULT> ArrayConstructorFromMold(
// array * array case
template <typename RESULT, typename LEFT, typename RIGHT>
-Expr<RESULT> MapOperation(FoldingContext &context,
+auto MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
- Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
+ Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues)
+ -> std::optional<Expr<RESULT>> {
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
@@ -1404,16 +1417,16 @@ Expr<RESULT> MapOperation(FoldingContext &context,
++rightIter;
}
}
- return FromArrayConstructor(
- context, std::move(result), AsConstantExtents(context, shape));
+ return FromArrayConstructor(context, std::move(result), shape);
}
// array * scalar case
template <typename RESULT, typename LEFT, typename RIGHT>
-Expr<RESULT> MapOperation(FoldingContext &context,
+auto MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
- Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar) {
+ Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar)
+ -> std::optional<Expr<RESULT>> {
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
for (auto &leftValue : leftArrConst) {
@@ -1421,16 +1434,16 @@ Expr<RESULT> MapOperation(FoldingContext &context,
result.Push(
Fold(context, f(std::move(leftScalar), Expr<RIGHT>{rightScalar})));
}
- return FromArrayConstructor(
- context, std::move(result), AsConstantExtents(context, shape));
+ return FromArrayConstructor(context, std::move(result), shape);
}
// scalar * array case
template <typename RESULT, typename LEFT, typename RIGHT>
-Expr<RESULT> MapOperation(FoldingContext &context,
+auto MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
- const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues) {
+ const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues)
+ -> std::optional<Expr<RESULT>> {
auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
common::visit(
@@ -1453,8 +1466,7 @@ Expr<RESULT> MapOperation(FoldingContext &context,
Fold(context, f(Expr<LEFT>{leftScalar}, std::move(rightScalar))));
}
}
- return FromArrayConstructor(
- context, std::move(result), AsConstantExtents(context, shape));
+ return FromArrayConstructor(context, std::move(result), shape);
}
template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
More information about the flang-commits mailing list