[flang-commits] [flang] bd1d170 - [flang] Prevent a bad folding rewrite

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Fri Sep 23 08:11:31 PDT 2022


Author: Peter Klausler
Date: 2022-09-23T08:11:16-07:00
New Revision: bd1d1701649c3d26e5b7ea788bb67f58f0779f47

URL: https://github.com/llvm/llvm-project/commit/bd1d1701649c3d26e5b7ea788bb67f58f0779f47
DIFF: https://github.com/llvm/llvm-project/commit/bd1d1701649c3d26e5b7ea788bb67f58f0779f47.diff

LOG: [flang] Prevent a bad folding rewrite

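When a subexpression lacks constant elements, a constant shape, or
both, folding was rewriting it into a rank-1 array constructor (a
vector) of its elements.  That is wrong whenever the shape shows
that the result has rank greater than 1.  FromArrayConstructor()
and the MapOperation() overloads now return std::optional: they
produce a value only when the elements and the shape are both
constant (the constant is reshaped), or when the shape is a known
rank-1 extent equal to the array constructor's own extent, and
std::nullopt otherwise.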

Differential Revision: https://reviews.llvm.org/D134392

Added: 
    

Modified: 
    flang/lib/Evaluate/fold-implementation.h

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 37f4f8db996f..4bcd41409029 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -1310,15 +1310,28 @@ AsFlatArrayConstructor(const Expr<SomeKind<CAT>> &expr) {
 // into an Expr<T>, folds it, and returns the resulting wrapped
 // array constructor or constant array value.
 template <typename T>
-Expr<T> FromArrayConstructor(FoldingContext &context,
-    ArrayConstructor<T> &&values, std::optional<ConstantSubscripts> &&shape) {
-  Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
-  if (shape) {
+std::optional<Expr<T>> FromArrayConstructor(
+    FoldingContext &context, ArrayConstructor<T> &&values, const Shape &shape) {
+  if (auto constShape{AsConstantExtents(context, shape)}) {
+    Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
     if (auto *constant{UnwrapConstantValue<T>(result)}) {
-      return Expr<T>{constant->Reshape(std::move(*shape))};
+      // Elements and shape are both constant.
+      return Expr<T>{constant->Reshape(std::move(*constShape))};
+    }
+    if (constShape->size() == 1) {
+      if (auto elements{GetShape(context, result)}) {
+        if (auto constElements{AsConstantExtents(context, *elements)}) {
+          if (constElements->size() == 1 &&
+              constElements->at(0) == constShape->at(0)) {
+            // Elements are not constant, but array constructor has
+            // the right known shape and can be simply returned as is.
+            return std::move(result);
+          }
+        }
+      }
     }
   }
-  return result;
+  return std::nullopt;
 }
 
 // MapOperation is a utility for various specializations of ApplyElementwise()
@@ -1330,7 +1343,7 @@ Expr<T> FromArrayConstructor(FoldingContext &context,
 
 // Unary case
 template <typename RESULT, typename OPERAND>
-Expr<RESULT> MapOperation(FoldingContext &context,
+std::optional<Expr<RESULT>> MapOperation(FoldingContext &context,
     std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f, const Shape &shape,
     Expr<OPERAND> &&values) {
   ArrayConstructor<RESULT> result{values};
@@ -1352,8 +1365,7 @@ Expr<RESULT> MapOperation(FoldingContext &context,
       result.Push(Fold(context, f(std::move(scalar))));
     }
   }
-  return FromArrayConstructor(
-      context, std::move(result), AsConstantExtents(context, shape));
+  return FromArrayConstructor(context, std::move(result), shape);
 }
 
 template <typename RESULT, typename A>
@@ -1369,10 +1381,11 @@ ArrayConstructor<RESULT> ArrayConstructorFromMold(
 
 // array * array case
 template <typename RESULT, typename LEFT, typename RIGHT>
-Expr<RESULT> MapOperation(FoldingContext &context,
+auto MapOperation(FoldingContext &context,
     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
     const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
-    Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
+    Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues)
+    -> std::optional<Expr<RESULT>> {
   auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
   auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
   if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
@@ -1404,16 +1417,16 @@ Expr<RESULT> MapOperation(FoldingContext &context,
       ++rightIter;
     }
   }
-  return FromArrayConstructor(
-      context, std::move(result), AsConstantExtents(context, shape));
+  return FromArrayConstructor(context, std::move(result), shape);
 }
 
 // array * scalar case
 template <typename RESULT, typename LEFT, typename RIGHT>
-Expr<RESULT> MapOperation(FoldingContext &context,
+auto MapOperation(FoldingContext &context,
     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
     const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
-    Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar) {
+    Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar)
+    -> std::optional<Expr<RESULT>> {
   auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
   auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
   for (auto &leftValue : leftArrConst) {
@@ -1421,16 +1434,16 @@ Expr<RESULT> MapOperation(FoldingContext &context,
     result.Push(
         Fold(context, f(std::move(leftScalar), Expr<RIGHT>{rightScalar})));
   }
-  return FromArrayConstructor(
-      context, std::move(result), AsConstantExtents(context, shape));
+  return FromArrayConstructor(context, std::move(result), shape);
 }
 
 // scalar * array case
 template <typename RESULT, typename LEFT, typename RIGHT>
-Expr<RESULT> MapOperation(FoldingContext &context,
+auto MapOperation(FoldingContext &context,
     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
     const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
-    const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues) {
+    const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues)
+    -> std::optional<Expr<RESULT>> {
   auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
   if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
     common::visit(
@@ -1453,8 +1466,7 @@ Expr<RESULT> MapOperation(FoldingContext &context,
           Fold(context, f(Expr<LEFT>{leftScalar}, std::move(rightScalar))));
     }
   }
-  return FromArrayConstructor(
-      context, std::move(result), AsConstantExtents(context, shape));
+  return FromArrayConstructor(context, std::move(result), shape);
 }
 
 template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>


        


More information about the flang-commits mailing list