[flang-commits] [flang] [flang] Optimize assignments of multidimensional arrays (PR #146408)

Leandro Lupori via flang-commits flang-commits at lists.llvm.org
Thu Jul 3 07:48:17 PDT 2025


https://github.com/luporl updated https://github.com/llvm/llvm-project/pull/146408

>From c9a350f2c1099ba0c26539acf300b672f4b557a6 Mon Sep 17 00:00:00 2001
From: Leandro Lupori <leandro.lupori at linaro.org>
Date: Mon, 9 Jun 2025 11:42:29 -0300
Subject: [PATCH 1/3] [flang] Optimize assignments of multidimensional arrays

Assignments of n-dimensional arrays with a trivial RHS were
always converted to n nested loops. For contiguous arrays,
it is possible to flatten them and use a single loop, which
LLVM can usually optimize better.

In a test program using a 3-dimensional array of varying size,
the resulting speedups were as follows (measured on Graviton4):

16K     1.09
64K     1.40
128K    1.90
256K    1.91
512K    1.00

For sizes of 512K or more, no improvement was observed.
It appears that LLVM stops performing aggressive loop
unrolling beyond a certain threshold and simply emits nested
loops instead. Larger arrays also no longer fit in the L1 and
L2 caches.

This was noticed while profiling 527.cam4_r. This optimization
makes aer_rad_props slightly faster, but unfortunately it has
practically no effect on the total execution time of 527.cam4_r.
---
 .../Transforms/OptimizedBufferization.cpp     | 43 ++++++++++++---
 flang/test/HLFIR/opt-scalar-assign.fir        | 53 ++++++++++++++-----
 2 files changed, 77 insertions(+), 19 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 91df8672c20d9..e88991b801415 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -786,13 +786,42 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
   mlir::Value shape = hlfir::genShape(loc, builder, lhs);
   llvm::SmallVector<mlir::Value> extents =
       hlfir::getIndexExtents(loc, builder, shape);
-  hlfir::LoopNest loopNest =
-      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
-                         flangomp::shouldUseWorkshareLowering(assign));
-  builder.setInsertionPointToStart(loopNest.body);
-  auto arrayElement =
-      hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
-  builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
+
+  if (lhs.isSimplyContiguous() && extents.size() > 1) {
+    // Flatten the array to use a single assign loop, that can be better
+    // optimized.
+    mlir::Value n = extents[0];
+    for (size_t i = 1; i < extents.size(); ++i)
+      n = builder.create<mlir::arith::MulIOp>(loc, n, extents[i]);
+    extents = {n};
+    shape = builder.genShape(loc, extents);
+    mlir::Type flatArrayType =
+        fir::ReferenceType::get(fir::SequenceType::get(eleTy, 1));
+    mlir::Value flatArray = lhs.getBase();
+    if (mlir::isa<fir::BoxType>(lhs.getType()))
+      flatArray = builder.create<fir::BoxAddrOp>(loc, flatArray);
+    flatArray = builder.createConvert(loc, flatArrayType, flatArray);
+
+    hlfir::LoopNest loopNest =
+        hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
+                           flangomp::shouldUseWorkshareLowering(assign));
+    builder.setInsertionPointToStart(loopNest.body);
+
+    mlir::Value coor = builder.create<fir::ArrayCoorOp>(
+        loc, fir::ReferenceType::get(eleTy), flatArray, shape,
+        /*slice=*/mlir::Value{}, loopNest.oneBasedIndices,
+        /*typeparams=*/mlir::ValueRange{});
+    builder.create<fir::StoreOp>(loc, rhs, coor);
+  } else {
+    hlfir::LoopNest loopNest =
+        hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
+                           flangomp::shouldUseWorkshareLowering(assign));
+    builder.setInsertionPointToStart(loopNest.body);
+    auto arrayElement =
+        hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
+    builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
+  }
+
   rewriter.eraseOp(assign);
   return mlir::success();
 }
diff --git a/flang/test/HLFIR/opt-scalar-assign.fir b/flang/test/HLFIR/opt-scalar-assign.fir
index 02ab02945b042..0f78d68f17ac8 100644
--- a/flang/test/HLFIR/opt-scalar-assign.fir
+++ b/flang/test/HLFIR/opt-scalar-assign.fir
@@ -12,18 +12,20 @@ func.func @_QPtest1() {
   return
 }
 // CHECK-LABEL:   func.func @_QPtest1() {
-// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_2:.*]] = arith.constant 11 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 13 : index
-// CHECK:           %[[VAL_4:.*]] = fir.alloca !fir.array<11x13xf32> {bindc_name = "x", uniq_name = "_QFtest1Ex"}
-// CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_2]], %[[VAL_3]] : (index, index) -> !fir.shape<2>
-// CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_5]]) {uniq_name = "_QFtest1Ex"} : (!fir.ref<!fir.array<11x13xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<11x13xf32>>, !fir.ref<!fir.array<11x13xf32>>)
-// CHECK:           fir.do_loop %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_3]] step %[[VAL_0]] unordered {
-// CHECK:             fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_0]] unordered {
-// CHECK:               %[[VAL_9:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_8]], %[[VAL_7]])  : (!fir.ref<!fir.array<11x13xf32>>, index, index) -> !fir.ref<f32>
-// CHECK:               hlfir.assign %[[VAL_1]] to %[[VAL_9]] : f32, !fir.ref<f32>
-// CHECK:             }
+// CHECK:           %[[VAL_0:.*]] = arith.constant 143 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_3:.*]] = arith.constant 11 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 13 : index
+// CHECK:           %[[VAL_5:.*]] = fir.alloca !fir.array<11x13xf32> {bindc_name = "x", uniq_name = "_QFtest1Ex"}
+// CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_3]], %[[VAL_4]] : (index, index) -> !fir.shape<2>
+// CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_6]]) {uniq_name = "_QFtest1Ex"} : (!fir.ref<!fir.array<11x13xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<11x13xf32>>, !fir.ref<!fir.array<11x13xf32>>)
+
+// CHECK:           %[[VAL_8:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_7]]#0 : (!fir.ref<!fir.array<11x13xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:           fir.do_loop %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_0]] step %[[VAL_1]] unordered {
+// CHECK:             %[[VAL_11:.*]] = fir.array_coor %[[VAL_9]](%[[VAL_8]]) %[[VAL_10]] : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:             fir.store %[[VAL_2]] to %[[VAL_11]] : !fir.ref<f32>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -129,3 +131,30 @@ func.func @_QPtest5(%arg0: !fir.ref<!fir.array<77xcomplex<f32>>> {fir.bindc_name
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
+
+func.func @_QPtest6(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {fir.bindc_name = "x"}) {
+  %c0_i32 = arith.constant 0 : i32
+  %0:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest6Ex"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
+  hlfir.assign %c0_i32 to %0#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
+  return
+}
+
+// CHECK-LABEL:   func.func @_QPtest6(
+// CHECK-SAME:                        %[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {fir.bindc_name = "x"}) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest6Ex"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
+// CHECK:           %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
+// CHECK:           %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_5]], %[[VAL_2]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_7:.*]]:3 = fir.box_dims %[[VAL_5]], %[[VAL_1]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_6]]#1, %[[VAL_7]]#1 : index
+// CHECK:           %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_10:.*]] = fir.box_addr %[[VAL_5]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.heap<!fir.array<?x?xi32>>
+// CHECK:           %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.heap<!fir.array<?x?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+// CHECK:           fir.do_loop %[[VAL_12:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_1]] unordered {
+// CHECK:             %[[VAL_13:.*]] = fir.array_coor %[[VAL_11]](%[[VAL_9]]) %[[VAL_12]] : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_13]] : !fir.ref<i32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }

>From cc08a23fd65f7fcbb8597c499956f7045ea9f40b Mon Sep 17 00:00:00 2001
From: Leandro Lupori <leandro.lupori at linaro.org>
Date: Wed, 2 Jul 2025 18:45:50 -0300
Subject: [PATCH 2/3] Use hlfir.designate and hlfir.assign

---
 .../Transforms/OptimizedBufferization.cpp     | 56 ++++++++++++++-----
 flang/test/HLFIR/opt-scalar-assign.fir        | 19 +++----
 2 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index e88991b801415..a6a2173af20db 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -21,6 +21,7 @@
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "flang/Optimizer/Support/Utils.h"
 #include "flang/Optimizer/Transforms/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Dominance.h"
@@ -758,6 +759,16 @@ class BroadcastAssignBufferization
                   mlir::PatternRewriter &rewriter) const override;
 };
 
+static bool isAllocatableArray(mlir::Type ty) {
+  auto boxTy = mlir::dyn_cast<fir::BoxType>(ty);
+  if (!boxTy)
+    return false;
+  auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getElementType());
+  if (!heapTy)
+    return false;
+  return mlir::isa<fir::SequenceType>(heapTy.getElementType());
+}
+
 llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
     hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
   // Since RHS is a scalar and LHS is an array, LHS must be allocated
@@ -787,31 +798,48 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
   llvm::SmallVector<mlir::Value> extents =
       hlfir::getIndexExtents(loc, builder, shape);
 
-  if (lhs.isSimplyContiguous() && extents.size() > 1) {
+  bool isArrayRef =
+      mlir::isa<fir::SequenceType>(fir::unwrapRefType(lhs.getType()));
+  if (lhs.isSimplyContiguous() && extents.size() > 1 &&
+      (isArrayRef || isAllocatableArray(lhs.getType()))) {
     // Flatten the array to use a single assign loop, that can be better
     // optimized.
     mlir::Value n = extents[0];
     for (size_t i = 1; i < extents.size(); ++i)
       n = builder.create<mlir::arith::MulIOp>(loc, n, extents[i]);
-    extents = {n};
-    shape = builder.genShape(loc, extents);
-    mlir::Type flatArrayType =
-        fir::ReferenceType::get(fir::SequenceType::get(eleTy, 1));
+    llvm::SmallVector<mlir::Value> flatExtents = {n};
+
+    mlir::Type flatArrayType;
     mlir::Value flatArray = lhs.getBase();
-    if (mlir::isa<fir::BoxType>(lhs.getType()))
-      flatArray = builder.create<fir::BoxAddrOp>(loc, flatArray);
-    flatArray = builder.createConvert(loc, flatArrayType, flatArray);
+    if (isArrayRef) {
+      // Array references must have fixed shape, when used in assignments.
+      int64_t flatExtent = 1;
+      for (const mlir::Value &extent : extents) {
+        mlir::Operation *op = extent.getDefiningOp();
+        assert(op && "no defining operation for constant array extent");
+        flatExtent *= fir::toInt(mlir::cast<mlir::arith::ConstantOp>(*op));
+      }
+
+      flatArrayType =
+          fir::ReferenceType::get(fir::SequenceType::get({flatExtent}, eleTy));
+      flatArray = builder.createConvert(loc, flatArrayType, flatArray);
+    } else {
+      shape = builder.genShape(loc, flatExtents);
+      flatArrayType = fir::BoxType::get(
+          fir::HeapType::get(fir::SequenceType::get(eleTy, 1)));
+      flatArray = builder.create<fir::ReboxOp>(loc, flatArrayType, flatArray,
+                                               shape, /*slice=*/mlir::Value{});
+    }
 
     hlfir::LoopNest loopNest =
-        hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
+        hlfir::genLoopNest(loc, builder, flatExtents, /*isUnordered=*/true,
                            flangomp::shouldUseWorkshareLowering(assign));
     builder.setInsertionPointToStart(loopNest.body);
 
-    mlir::Value coor = builder.create<fir::ArrayCoorOp>(
-        loc, fir::ReferenceType::get(eleTy), flatArray, shape,
-        /*slice=*/mlir::Value{}, loopNest.oneBasedIndices,
-        /*typeparams=*/mlir::ValueRange{});
-    builder.create<fir::StoreOp>(loc, rhs, coor);
+    mlir::Value arrayElement =
+        builder.create<hlfir::DesignateOp>(loc, fir::ReferenceType::get(eleTy),
+                                           flatArray, loopNest.oneBasedIndices);
+    builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
   } else {
     hlfir::LoopNest loopNest =
         hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
diff --git a/flang/test/HLFIR/opt-scalar-assign.fir b/flang/test/HLFIR/opt-scalar-assign.fir
index 0f78d68f17ac8..cbcdd579ff9e2 100644
--- a/flang/test/HLFIR/opt-scalar-assign.fir
+++ b/flang/test/HLFIR/opt-scalar-assign.fir
@@ -20,12 +20,10 @@ func.func @_QPtest1() {
 // CHECK:           %[[VAL_5:.*]] = fir.alloca !fir.array<11x13xf32> {bindc_name = "x", uniq_name = "_QFtest1Ex"}
 // CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_3]], %[[VAL_4]] : (index, index) -> !fir.shape<2>
 // CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_6]]) {uniq_name = "_QFtest1Ex"} : (!fir.ref<!fir.array<11x13xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<11x13xf32>>, !fir.ref<!fir.array<11x13xf32>>)
-
-// CHECK:           %[[VAL_8:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
-// CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_7]]#0 : (!fir.ref<!fir.array<11x13xf32>>) -> !fir.ref<!fir.array<?xf32>>
-// CHECK:           fir.do_loop %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_0]] step %[[VAL_1]] unordered {
-// CHECK:             %[[VAL_11:.*]] = fir.array_coor %[[VAL_9]](%[[VAL_8]]) %[[VAL_10]] : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
-// CHECK:             fir.store %[[VAL_2]] to %[[VAL_11]] : !fir.ref<f32>
+// CHECK:           %[[VAL_8:.*]] = fir.convert %[[VAL_7]]#0 : (!fir.ref<!fir.array<11x13xf32>>) -> !fir.ref<!fir.array<143xf32>>
+// CHECK:           fir.do_loop %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_0]] step %[[VAL_1]] unordered {
+// CHECK:             %[[VAL_10:.*]] = hlfir.designate %[[VAL_8]] (%[[VAL_9]]) : (!fir.ref<!fir.array<143xf32>>, index) -> !fir.ref<f32>
+// CHECK:             hlfir.assign %[[VAL_2]] to %[[VAL_10]] : f32, !fir.ref<f32>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -150,11 +148,10 @@ func.func @_QPtest6(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {f
 // CHECK:           %[[VAL_7:.*]]:3 = fir.box_dims %[[VAL_5]], %[[VAL_1]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, index) -> (index, index, index)
 // CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_6]]#1, %[[VAL_7]]#1 : index
 // CHECK:           %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
-// CHECK:           %[[VAL_10:.*]] = fir.box_addr %[[VAL_5]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.heap<!fir.array<?x?xi32>>
-// CHECK:           %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.heap<!fir.array<?x?xi32>>) -> !fir.ref<!fir.array<?xi32>>
-// CHECK:           fir.do_loop %[[VAL_12:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_1]] unordered {
-// CHECK:             %[[VAL_13:.*]] = fir.array_coor %[[VAL_11]](%[[VAL_9]]) %[[VAL_12]] : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
-// CHECK:             fir.store %[[VAL_3]] to %[[VAL_13]] : !fir.ref<i32>
+// CHECK:           %[[VAL_10:.*]] = fir.rebox %[[VAL_5]](%[[VAL_9]]) : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+// CHECK:           fir.do_loop %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_1]] unordered {
+// CHECK:             %[[VAL_12:.*]] = hlfir.designate %[[VAL_10]] (%[[VAL_11]]) : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> !fir.ref<i32>
+// CHECK:             hlfir.assign %[[VAL_3]] to %[[VAL_12]] : i32, !fir.ref<i32>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }

>From 2eb7b7e788be2d7b59ff7c58ec26087554c8a8eb Mon Sep 17 00:00:00 2001
From: Leandro Lupori <leandro.lupori at linaro.org>
Date: Thu, 3 Jul 2025 11:34:13 -0300
Subject: [PATCH 3/3] Remove isAllocatableArray check

---
 .../Transforms/OptimizedBufferization.cpp     | 28 +++++--------------
 flang/test/HLFIR/opt-scalar-assign.fir        |  4 +--
 2 files changed, 9 insertions(+), 23 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index a6a2173af20db..54892ef99bf58 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -759,16 +759,6 @@ class BroadcastAssignBufferization
                   mlir::PatternRewriter &rewriter) const override;
 };
 
-static bool isAllocatableArray(mlir::Type ty) {
-  auto boxTy = mlir::dyn_cast<fir::BoxType>(ty);
-  if (!boxTy)
-    return false;
-  auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getElementType());
-  if (!heapTy)
-    return false;
-  return mlir::isa<fir::SequenceType>(heapTy.getElementType());
-}
-
 llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
     hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
   // Since RHS is a scalar and LHS is an array, LHS must be allocated
@@ -798,10 +788,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
   llvm::SmallVector<mlir::Value> extents =
       hlfir::getIndexExtents(loc, builder, shape);
 
-  bool isArrayRef =
-      mlir::isa<fir::SequenceType>(fir::unwrapRefType(lhs.getType()));
-  if (lhs.isSimplyContiguous() && extents.size() > 1 &&
-      (isArrayRef || isAllocatableArray(lhs.getType()))) {
+  if (lhs.isSimplyContiguous() && extents.size() > 1) {
     // Flatten the array to use a single assign loop, that can be better
     // optimized.
     mlir::Value n = extents[0];
@@ -811,7 +798,12 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
 
     mlir::Type flatArrayType;
     mlir::Value flatArray = lhs.getBase();
-    if (isArrayRef) {
+    if (mlir::isa<fir::BoxType>(lhs.getType())) {
+      shape = builder.genShape(loc, flatExtents);
+      flatArrayType = fir::BoxType::get(fir::SequenceType::get(eleTy, 1));
+      flatArray = builder.create<fir::ReboxOp>(loc, flatArrayType, flatArray,
+                                               shape, /*slice=*/mlir::Value{});
+    } else {
       // Array references must have fixed shape, when used in assignments.
       int64_t flatExtent = 1;
       for (const mlir::Value &extent : extents) {
@@ -823,12 +815,6 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
       flatArrayType =
           fir::ReferenceType::get(fir::SequenceType::get({flatExtent}, eleTy));
       flatArray = builder.createConvert(loc, flatArrayType, flatArray);
-    } else {
-      shape = builder.genShape(loc, flatExtents);
-      flatArrayType = fir::BoxType::get(
-          fir::HeapType::get(fir::SequenceType::get(eleTy, 1)));
-      flatArray = builder.create<fir::ReboxOp>(loc, flatArrayType, flatArray,
-                                               shape, /*slice=*/mlir::Value{});
     }
 
     hlfir::LoopNest loopNest =
diff --git a/flang/test/HLFIR/opt-scalar-assign.fir b/flang/test/HLFIR/opt-scalar-assign.fir
index cbcdd579ff9e2..74cdcd9622adb 100644
--- a/flang/test/HLFIR/opt-scalar-assign.fir
+++ b/flang/test/HLFIR/opt-scalar-assign.fir
@@ -148,9 +148,9 @@ func.func @_QPtest6(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {f
 // CHECK:           %[[VAL_7:.*]]:3 = fir.box_dims %[[VAL_5]], %[[VAL_1]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, index) -> (index, index, index)
 // CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_6]]#1, %[[VAL_7]]#1 : index
 // CHECK:           %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
-// CHECK:           %[[VAL_10:.*]] = fir.rebox %[[VAL_5]](%[[VAL_9]]) : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+// CHECK:           %[[VAL_10:.*]] = fir.rebox %[[VAL_5]](%[[VAL_9]]) : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
 // CHECK:           fir.do_loop %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_1]] unordered {
-// CHECK:             %[[VAL_12:.*]] = hlfir.designate %[[VAL_10]] (%[[VAL_11]]) : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> !fir.ref<i32>
+// CHECK:             %[[VAL_12:.*]] = hlfir.designate %[[VAL_10]] (%[[VAL_11]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK:             hlfir.assign %[[VAL_3]] to %[[VAL_12]] : i32, !fir.ref<i32>
 // CHECK:           }
 // CHECK:           return



More information about the flang-commits mailing list