[flang-commits] [flang] e0adee8 - [flang] Correct folding of CSHIFT and EOSHIFT for DIM>1

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Fri Jun 3 19:00:04 PDT 2022


Author: Peter Klausler
Date: 2022-06-03T18:59:44-07:00
New Revision: e0adee8481623613933551e00adcd9ddea18d889

URL: https://github.com/llvm/llvm-project/commit/e0adee8481623613933551e00adcd9ddea18d889
DIFF: https://github.com/llvm/llvm-project/commit/e0adee8481623613933551e00adcd9ddea18d889.diff

LOG: [flang] Correct folding of CSHIFT and EOSHIFT for DIM>1

The folding algorithm for CSHIFT and EOSHIFT was wrong when DIM > 1,
and so were the expected test results.  Rework the algorithm and correct the tests.

Differential Revision: https://reviews.llvm.org/D127018

Added: 
    

Modified: 
    flang/lib/Evaluate/fold-implementation.h
    flang/test/Evaluate/folding23.f90
    flang/test/Evaluate/folding27.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 97abfc688d63..c0dee020f8fb 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -613,26 +613,33 @@ template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
     }
     if (ok) {
       std::vector<Scalar<T>> resultElements;
-      ConstantSubscripts arrayAt{array->lbounds()};
-      ConstantSubscript dimLB{arrayAt[zbDim]};
+      ConstantSubscripts arrayLB{array->lbounds()};
+      ConstantSubscripts arrayAt{arrayLB};
+      ConstantSubscript &dimIndex{arrayAt[zbDim]};
+      ConstantSubscript dimLB{dimIndex}; // initial value
       ConstantSubscript dimExtent{array->shape()[zbDim]};
-      ConstantSubscripts shiftAt{shift->lbounds()};
-      for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
-        ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
-        ConstantSubscript zbDimIndex{shiftCount % dimExtent};
-        if (zbDimIndex < 0) {
-          zbDimIndex += dimExtent;
-        }
-        for (ConstantSubscript j{0}; j < dimExtent; ++j) {
-          arrayAt[zbDim] = dimLB + zbDimIndex;
-          resultElements.push_back(array->At(arrayAt));
-          if (++zbDimIndex == dimExtent) {
-            zbDimIndex = 0;
+      ConstantSubscripts shiftLB{shift->lbounds()};
+      for (auto n{GetSize(array->shape())}; n > 0; --n) {
+        ConstantSubscript origDimIndex{dimIndex};
+        ConstantSubscripts shiftAt;
+        if (shift->Rank() > 0) {
+          int k{0};
+          for (int j{0}; j < rank; ++j) {
+            if (j != zbDim) {
+              shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]);
+            }
           }
         }
-        arrayAt[zbDim] = dimLB + std::max<ConstantSubscript>(dimExtent, 1) - 1;
+        ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
+        dimIndex = dimLB + ((dimIndex - dimLB + shiftCount) % dimExtent);
+        if (dimIndex < dimLB) {
+          dimIndex += dimExtent;
+        } else if (dimIndex >= dimLB + dimExtent) {
+          dimIndex -= dimExtent;
+        }
+        resultElements.push_back(array->At(arrayAt));
+        dimIndex = origDimIndex;
         array->IncrementSubscripts(arrayAt);
-        shift->IncrementSubscripts(shiftAt);
       }
       return Expr<T>{PackageConstant<T>(
           std::move(resultElements), *array, array->shape())};
@@ -714,42 +721,57 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
     }
     if (ok) {
       std::vector<Scalar<T>> resultElements;
-      ConstantSubscripts arrayAt{array->lbounds()};
-      ConstantSubscript dimLB{arrayAt[zbDim]};
+      ConstantSubscripts arrayLB{array->lbounds()};
+      ConstantSubscripts arrayAt{arrayLB};
+      ConstantSubscript &dimIndex{arrayAt[zbDim]};
+      ConstantSubscript dimLB{dimIndex}; // initial value
       ConstantSubscript dimExtent{array->shape()[zbDim]};
-      ConstantSubscripts shiftAt{shift->lbounds()};
-      ConstantSubscripts boundaryAt;
+      ConstantSubscripts shiftLB{shift->lbounds()};
+      ConstantSubscripts boundaryLB;
       if (boundary) {
-        boundaryAt = boundary->lbounds();
+        boundaryLB = boundary->lbounds();
       }
-      for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
+      for (auto n{GetSize(array->shape())}; n > 0; --n) {
+        ConstantSubscript origDimIndex{dimIndex};
+        ConstantSubscripts shiftAt;
+        if (shift->Rank() > 0) {
+          int k{0};
+          for (int j{0}; j < rank; ++j) {
+            if (j != zbDim) {
+              shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]);
+            }
+          }
+        }
         ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
-        for (ConstantSubscript j{0}; j < dimExtent; ++j) {
-          ConstantSubscript zbAt{shiftCount + j};
-          if (zbAt >= 0 && zbAt < dimExtent) {
-            arrayAt[zbDim] = dimLB + zbAt;
-            resultElements.push_back(array->At(arrayAt));
-          } else if (boundary) {
-            resultElements.push_back(boundary->At(boundaryAt));
-          } else if constexpr (T::category == TypeCategory::Integer ||
-              T::category == TypeCategory::Real ||
-              T::category == TypeCategory::Complex ||
-              T::category == TypeCategory::Logical) {
-            resultElements.emplace_back();
-          } else if constexpr (T::category == TypeCategory::Character) {
-            auto len{static_cast<std::size_t>(array->LEN())};
-            typename Scalar<T>::value_type space{' '};
-            resultElements.emplace_back(len, space);
-          } else {
-            DIE("no derived type boundary");
+        dimIndex += shiftCount;
+        if (dimIndex >= dimLB && dimIndex < dimLB + dimExtent) {
+          resultElements.push_back(array->At(arrayAt));
+        } else if (boundary) {
+          ConstantSubscripts boundaryAt;
+          if (boundary->Rank() > 0) {
+            for (int j{0}; j < rank; ++j) {
+              int k{0};
+              if (j != zbDim) {
+                boundaryAt.emplace_back(
+                    boundaryLB[k++] + arrayAt[j] - arrayLB[j]);
+              }
+            }
           }
+          resultElements.push_back(boundary->At(boundaryAt));
+        } else if constexpr (T::category == TypeCategory::Integer ||
+            T::category == TypeCategory::Real ||
+            T::category == TypeCategory::Complex ||
+            T::category == TypeCategory::Logical) {
+          resultElements.emplace_back();
+        } else if constexpr (T::category == TypeCategory::Character) {
+          auto len{static_cast<std::size_t>(array->LEN())};
+          typename Scalar<T>::value_type space{' '};
+          resultElements.emplace_back(len, space);
+        } else {
+          DIE("no derived type boundary");
         }
-        arrayAt[zbDim] = dimLB + std::max<ConstantSubscript>(dimExtent, 1) - 1;
+        dimIndex = origDimIndex;
         array->IncrementSubscripts(arrayAt);
-        shift->IncrementSubscripts(shiftAt);
-        if (boundary) {
-          boundary->IncrementSubscripts(boundaryAt);
-        }
       }
       return Expr<T>{PackageConstant<T>(
           std::move(resultElements), *array, array->shape())};

diff  --git a/flang/test/Evaluate/folding23.f90 b/flang/test/Evaluate/folding23.f90
index f31478ed3c5e..c25d2fc93982 100644
--- a/flang/test/Evaluate/folding23.f90
+++ b/flang/test/Evaluate/folding23.f90
@@ -9,7 +9,7 @@ module m
   logical, parameter :: test_eoshift_3 = all(eoshift([1., 2., 3.], 1) == [2., 3., 0.])
   logical, parameter :: test_eoshift_4 = all(eoshift(['ab', 'cd', 'ef'], -1, 'x') == ['x ', 'ab', 'cd'])
   logical, parameter :: test_eoshift_5 = all([eoshift(arr, 1, dim=1)] == [2, 0, 4, 0, 6, 0])
-  logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 5, 0, 4, 6, 0])
+  logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 4, 5, 6, 0, 0])
   logical, parameter :: test_eoshift_7 = all([eoshift(arr, [1, -1, 0])] == [2, 0, 0, 3, 5, 6])
-  logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 5, 0, 0, 2, 4])
+  logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 0, 5, 2, 0, 4])
 end module

diff  --git a/flang/test/Evaluate/folding27.f90 b/flang/test/Evaluate/folding27.f90
index 0d3d333c0f10..43699184f31a 100644
--- a/flang/test/Evaluate/folding27.f90
+++ b/flang/test/Evaluate/folding27.f90
@@ -9,7 +9,7 @@ module m
   logical, parameter :: test_cshift_3 = all(cshift([1, 2, 3], 4) == [2, 3, 1])
   logical, parameter :: test_cshift_4 = all(cshift([1, 2, 3], -1) == [3, 1, 2])
   logical, parameter :: test_cshift_5 = all([cshift(arr, 1, dim=1)] == [2, 1, 4, 3, 6, 5])
-  logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 5, 1, 4, 6, 2])
+  logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 4, 5, 6, 1, 2])
   logical, parameter :: test_cshift_7 = all([cshift(arr, [1, 2, 3])] == [2, 1, 3, 4, 6, 5])
-  logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 5, 1, 6, 2, 4])
+  logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 6, 5, 2, 1, 4])
 end module


        


More information about the flang-commits mailing list