[flang-commits] [flang] 5e5176a - [flang] Don't fold operation when shapes differ

Leandro Lupori via flang-commits flang-commits at lists.llvm.org
Fri Mar 31 05:13:36 PDT 2023


Author: Leandro Lupori
Date: 2023-03-31T09:13:14-03:00
New Revision: 5e5176adb1834b215412a90acadfa0d14c18de8f

URL: https://github.com/llvm/llvm-project/commit/5e5176adb1834b215412a90acadfa0d14c18de8f
DIFF: https://github.com/llvm/llvm-project/commit/5e5176adb1834b215412a90acadfa0d14c18de8f.diff

LOG: [flang] Don't fold operation when shapes differ

When folding a binary operation between two array constructors, it
is necessary to check if each value contained in the left operand
has the same rank and shape as the one on the right.
Otherwise, lowering would end up with an operation between values
of different ranks/shapes, which could result in a crash.

For instance, the following code was crashing the compiler:
integer :: x(4), y(2, 2), z(4)

z = (/x/) + (/y/)

Fixes #60229

Reviewed By: klausler, jeanPerier

Differential Revision: https://reviews.llvm.org/D147181

Added: 
    

Modified: 
    flang/lib/Evaluate/fold-implementation.h
    flang/test/Evaluate/rewrite01.f90
    flang/test/Lower/array-expression.f90

Removed: 
    


################################################################################
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index bff0f52f0881f..16924e9bd0c01 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -1381,6 +1381,28 @@ ArrayConstructor<RESULT> ArrayConstructorFromMold(
   return result;
 }
 
+template <typename LEFT, typename RIGHT>
+bool ShapesMatch(FoldingContext &context,
+    const ArrayConstructor<LEFT> &leftArrConst,
+    const ArrayConstructor<RIGHT> &rightArrConst) {
+  auto rightIter{rightArrConst.begin()};
+  for (auto &leftValue : leftArrConst) {
+    CHECK(rightIter != rightArrConst.end());
+    auto &leftExpr{std::get<Expr<LEFT>>(leftValue.u)};
+    auto &rightExpr{std::get<Expr<RIGHT>>(rightIter->u)};
+    if (leftExpr.Rank() != rightExpr.Rank()) {
+      return false;
+    }
+    std::optional<Shape> leftShape{GetShape(context, leftExpr)};
+    std::optional<Shape> rightShape{GetShape(context, rightExpr)};
+    if (!leftShape || !rightShape || *leftShape != *rightShape) {
+      return false;
+    }
+    ++rightIter;
+  }
+  return true;
+}
+
 // array * array case
 template <typename RESULT, typename LEFT, typename RIGHT>
 auto MapOperation(FoldingContext &context,
@@ -1391,11 +1413,14 @@ auto MapOperation(FoldingContext &context,
   auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
   auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
   if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
-    common::visit(
-        [&](auto &&kindExpr) {
+    bool mapped{common::visit(
+        [&](auto &&kindExpr) -> bool {
           using kindType = ResultType<decltype(kindExpr)>;
 
           auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
+          if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
+            return false;
+          }
           auto rightIter{rightArrConst.begin()};
           for (auto &leftValue : leftArrConst) {
             CHECK(rightIter != rightArrConst.end());
@@ -1405,10 +1430,17 @@ auto MapOperation(FoldingContext &context,
                 f(std::move(leftScalar), Expr<RIGHT>{std::move(rightScalar)})));
             ++rightIter;
           }
+          return true;
         },
-        std::move(rightValues.u));
+        std::move(rightValues.u))};
+    if (!mapped) {
+      return std::nullopt;
+    }
   } else {
     auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
+    if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
+      return std::nullopt;
+    }
     auto rightIter{rightArrConst.begin()};
     for (auto &leftValue : leftArrConst) {
       CHECK(rightIter != rightArrConst.end());

diff --git a/flang/test/Evaluate/rewrite01.f90 b/flang/test/Evaluate/rewrite01.f90
index 3f7f954cc69f5..cbcbfb97189c7 100644
--- a/flang/test/Evaluate/rewrite01.f90
+++ b/flang/test/Evaluate/rewrite01.f90
@@ -196,7 +196,9 @@ subroutine may_change_p_bounds(p)
 end subroutine
 
 !CHECK-LABEL: array_constructor
-subroutine array_constructor()
+subroutine array_constructor(a, u, v, w, x, y, z)
+  real :: a(4)
+  integer :: u(:), v(1), w(2), x(4), y(4), z(2, 2)
   interface
     function return_allocatable()
      real, allocatable :: return_allocatable(:)
@@ -204,6 +206,28 @@ function return_allocatable()
   end interface
   !CHECK: PRINT *, size([REAL(4)::return_allocatable(),return_allocatable()])
   print *, size([return_allocatable(), return_allocatable()])
+  !CHECK: PRINT *, [INTEGER(4)::x+y]
+  print *, (/x/) + (/y/)
+  !CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::z]
+  print *, (/x/) + (/z/)
+  !CHECK: PRINT *, [INTEGER(4)::x+y,x+y]
+  print *, (/x, x/) + (/y, y/)
+  !CHECK: PRINT *, [INTEGER(4)::x,x]+[INTEGER(4)::x,z]
+  print *, (/x, x/) + (/x, z/)
+  !CHECK: PRINT *, [INTEGER(4)::x,w,w]+[INTEGER(4)::w,w,x]
+  print *, (/x, w, w/) + (/w, w, x/)
+  !CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::1_4,2_4,3_4,4_4]
+  print *, (/x/) + (/1, 2, 3, 4/)
+  !CHECK: PRINT *, [INTEGER(4)::v]+[INTEGER(4)::1_4]
+  print *, (/v/) + (/1/)
+  !CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::u]
+  print *, (/x/) + (/u/)
+  !CHECK: PRINT *, [INTEGER(4)::u]+[INTEGER(4)::u]
+  print *, (/u/) + (/u/)
+  !CHECK: PRINT *, [REAL(4)::a**x]
+  print *, (/a/) ** (/x/)
+  !CHECK: PRINT *, [REAL(4)::a]**[INTEGER(4)::z]
+  print *, (/a/) ** (/z/)
 end subroutine
 
 !CHECK-LABEL: array_ctor_implied_do_index

diff --git a/flang/test/Lower/array-expression.f90 b/flang/test/Lower/array-expression.f90
index a373548f0232a..820ca42e95acf 100644
--- a/flang/test/Lower/array-expression.f90
+++ b/flang/test/Lower/array-expression.f90
@@ -1158,4 +1158,108 @@ subroutine test_elemental_character_intrinsic(c1, c2)
   print *, scan(c1, c2)
 end subroutine
 
+! Check that the expression is folded, with the first operation being an add
+! between x and y, resulting in a new temporary array.
+!
+! CHECK-LABEL: func @_QPtest20a(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG2:.*]]: !fir.ref<!fir.array<4xi32>>
+! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32>
+! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) {
+! CHECK:   %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
+! CHECK:   %[[YI:.*]] = fir.array_fetch %[[Y]], %[[I]] : (!fir.array<4xi32>, index) -> i32
+! CHECK:   %[[ADD:.*]] = arith.addi %[[XI]], %[[YI]] : i32
+! CHECK:   {{.*}} = fir.array_update %[[TEMP3]], %[[ADD]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
+! CHECK: }
+subroutine test20a(x, y, z)
+  integer :: x(4), y(4), z(4)
+
+  z = (/x/) + (/y/)
+end subroutine
+
+! Check that the expression is not folded, with the first operations being
+! array constructions from x and y.
+!
+! CHECK-LABEL: func @_QPtest20b(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<2x2xi32>> {{.*}}, %[[ARG2:.*]]: !fir.ref<!fir.array<4xi32>>
+! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32>
+! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) {
+! CHECK:   %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
+! CHECK:   {{.*}} = fir.array_update %[[TEMP3]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
+! CHECK: }
+! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
+! CHECK: %[[TEMP4:.*]] = fir.allocmem !fir.array<2x2xi32>
+! CHECK: %[[TEMP5:.*]] = fir.array_load %[[TEMP4]]({{.*}}) : (!fir.heap<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
+! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP6:.*]] = %[[TEMP5]]) -> (!fir.array<2x2xi32>) {
+! CHECK:   {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP7:.*]] = %[[TEMP6]]) -> (!fir.array<2x2xi32>) {
+! CHECK:     %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32
+! CHECK:     {{.*}} = fir.array_update %[[TEMP7]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32>
+! CHECK:   }
+! CHECK: }
+subroutine test20b(x, y, z)
+  integer :: x(4), y(2, 2), z(4)
+
+  z = (/x/) + (/y/)
+end subroutine
+
+! CHECK-LABEL: func @_QPtest20c(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<4xi32>> {{.*}}, %[[ARG1:.*]]: !fir.ref<!fir.array<2x2xi32>> {{.*}}
+
+! (/x/)
+! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: %[[ACX_MEM:.*]] = fir.allocmem !fir.array<4xi32>
+! CHECK: %[[ACX:.*]] = fir.array_load %[[ACX_MEM]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACX]]) -> (!fir.array<4xi32>) {
+! CHECK:   %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32
+! CHECK:   {{.*}} = fir.array_update %[[TEMP]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32>
+! CHECK: }
+! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACX_MEM2:.*]], %{{.*}} : (!fir.heap<!fir.array<4xi32>>, index) -> !fir.ref<i32>
+! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref<i32>) -> !fir.ref<i8>
+! CHECK: %[[T2:.*]] = fir.convert %[[ACX_MEM]] : (!fir.heap<!fir.array<4xi32>>) -> !fir.ref<i8>
+! CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}})
+! CHECK: %[[ACX2:.*]] = fir.array_load %[[ACX_MEM2]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+
+! (/y/)
+! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
+! CHECK: %[[ACY_MEM:.*]] = fir.allocmem !fir.array<2x2xi32>
+! CHECK: %[[ACY:.*]] = fir.array_load %[[ACY_MEM]]({{.*}}) : (!fir.heap<!fir.array<2x2xi32>>, !fir.shape<2>) -> !fir.array<2x2xi32>
+! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACY]]) -> (!fir.array<2x2xi32>) {
+! CHECK:   {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP2:.*]] = %[[TEMP]]) -> (!fir.array<2x2xi32>) {
+! CHECK:     %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32
+! CHECK:     {{.*}} = fir.array_update %[[TEMP2]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32>
+! CHECK:   }
+! CHECK: }
+! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACY_MEM2:.*]], {{.*}} : (!fir.heap<!fir.array<4xi32>>, index) -> !fir.ref<i32>
+! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref<i32>) -> !fir.ref<i8>
+! CHECK: %[[T2:.*]] = fir.convert %[[ACY_MEM]] : (!fir.heap<!fir.array<2x2xi32>>) -> !fir.ref<i8>
+! CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}})
+! CHECK: %[[ACY2:.*]] = fir.array_load %[[ACY_MEM2]]({{.*}}) : (!fir.heap<!fir.array<4xi32>>, !fir.shape<1>) -> !fir.array<4xi32>
+
+! (/x/) /= (/y/)
+! CHECK: %[[RES_MEM:.*]] = fir.allocmem !fir.array<4x!fir.logical<4>>
+! CHECK: %[[RES:.*]] = fir.array_load %[[RES_MEM]]({{.*}}) : (!fir.heap<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> !fir.array<4x!fir.logical<4>>
+! CHECK: %{{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[RES]]) -> (!fir.array<4x!fir.logical<4>>) {
+! CHECK:   %[[XI:.*]] = fir.array_fetch %[[ACX2]], %[[I]] : (!fir.array<4xi32>, index) -> i32
+! CHECK:   %[[YI:.*]] = fir.array_fetch %[[ACY2]], %[[I]] : (!fir.array<4xi32>, index) -> i32
+! CHECK:   %[[T1:.*]] = arith.cmpi ne, %[[XI]], %[[YI]] : i32
+! CHECK:   %[[T2:.*]] = fir.convert %[[T1]] : (i1) -> !fir.logical<4>
+! CHECK:   {{.*}} = fir.array_update %[[TEMP]], %[[T2]], %[[I]] : (!fir.array<4x!fir.logical<4>>, !fir.logical<4>, index) -> !fir.array<4x!fir.logical<4>>
+! CHECK: }
+
+! any((/x/) /= (/y/))
+! CHECK: %[[T1:.*]] = fir.embox %[[RES_MEM]]({{.*}}) : (!fir.heap<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.array<4x!fir.logical<4>>>
+! CHECK: %[[T2:.*]] = fir.convert %[[T1]] : (!fir.box<!fir.array<4x!fir.logical<4>>>) -> !fir.box<none>
+! CHECK: fir.call @_FortranAAny(%[[T2]], {{.*}}){{.*}} : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i1
+subroutine test20c(x, y)
+  integer :: x(4), y(2, 2)
+
+  if (any((/x/) /= (/y/))) print *, "different"
+end subroutine
+
 ! CHECK: func private @_QPbar(


        


More information about the flang-commits mailing list