[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

ioana ghiban llvmlistbot at llvm.org
Thu Apr 23 04:49:06 PDT 2026


https://github.com/ioghiban updated https://github.com/llvm/llvm-project/pull/188459

>From 4f0fd73ab6aa1dec0a7c48bba767e0602a39a5dc Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Tue, 24 Mar 2026 17:34:52 +0100
Subject: [PATCH 1/4] [memref] Simplify loads from reinterpret_cast of 1D
 contiguous memrefs

Assisted-by: ChatGPT (refine implementation + tests). I reviewed all code and tests before submission.
---
 .../Transforms/ElideReinterpretCast.cpp       | 241 +++++++++++++-
 .../MemRef/elide-reinterpret-cast.mlir        | 309 +++++++++++++++++-
 2 files changed, 548 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
index 01632c6ea1579..12b0b0bb7bd6a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Repeated.h"
@@ -196,6 +197,237 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
+static bool isConstZero(Value v) { return matchPattern(v, m_Zero()); }
+
+static bool isPureRankReshape(memref::ReinterpretCastOp rc, memref::LoadOp op) {
+  auto inputTy = cast<MemRefType>(rc.getSource().getType());
+  auto outputTy = cast<MemRefType>(rc.getResult().getType());
+
+  // This fold only handles reinterpret_casts that behave like pure rank
+  // reshapes of a single logical dimension:
+  //
+  //   - all metadata is static
+  //   - offset is 0
+  //   - source/result each have at most one non-unit dim
+  //   - if a non-unit dim exists, it is at the left or right boundary
+  //
+  // Examples accepted by this shape restriction:
+  //   memref<999xf32>       <-> memref<1x1x999xf32>
+  //   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
+  //   memref<100x1xf32>     <-> memref<100x1x1xf32>
+  //
+  // General reinterpret_casts are intentionally rejected.
+
+  auto offsets = rc.getStaticOffsets();
+  assert(offsets.size() == 1 && "Expecting single offset");
+
+  // The rewrite drops the reinterpret_cast and remaps indices directly to the
+  // source memref. That is only correct if there is no storage shift.
+  if (ShapedType::isDynamic(offsets[0]) || offsets[0] != 0)
+    return false;
+
+  auto sizes = rc.getStaticSizes();
+  auto strides = rc.getStaticStrides();
+
+  // Require fully static metadata. The fold relies on knowing exactly which
+  // dimensions are unit dimensions and which indices may be ignored.
+  if (llvm::any_of(sizes, ShapedType::isDynamic))
+    return false;
+  if (llvm::any_of(strides, ShapedType::isDynamic))
+    return false;
+
+  // Count non-unit dims and remember their positions.
+  //
+  // The rewrite supports shapes with at most one non-unit dimension.
+  // This excludes underlying multi-dimensional layouts and keeps the
+  // fold limited to unit-dim insertion/removal reshapes.
+  unsigned inputRank = inputTy.getRank();
+  int inputNonUnitCount = 0;
+  int64_t inputNonUnitSize = 1;
+  unsigned inputNonUnitPos = 0;
+  for (unsigned i = 0; i < inputRank; ++i) {
+    if (inputTy.getDimSize(i) != 1) {
+      ++inputNonUnitCount;
+      inputNonUnitPos = i;
+      inputNonUnitSize = inputTy.getDimSize(i);
+    }
+  }
+
+  unsigned outputRank = outputTy.getRank();
+  int outputNonUnitCount = 0;
+  int64_t outputNonUnitSize = 1;
+  unsigned outputNonUnitPos = 0;
+  for (unsigned i = 0; i < outputRank; ++i) {
+    if (outputTy.getDimSize(i) != 1) {
+      ++outputNonUnitCount;
+      outputNonUnitPos = i;
+      outputNonUnitSize = outputTy.getDimSize(i);
+    }
+  }
+
+  // Reject reshapes with > 1 non-unit-dimension.
+  //
+  // The source and result must have the same number of non-unit dimensions:
+  // either both are all-ones, or both have exactly one non-unit dimension.
+  if (inputNonUnitCount > 1 || outputNonUnitCount > 1 ||
+      inputNonUnitCount != outputNonUnitCount)
+    return false;
+
+  // If there is a non-unit dimension, it must sit on a boundary (first or
+  // last dimension) of both the input and the output memref. The rewrite
+  // logic for preserving the load index is exclusive to these cases.
+  // NOTE(review): each side is checked independently, not for the same side.
+  if (inputNonUnitCount == 1) {
+    auto isBoundary = [](unsigned pos, unsigned rank) {
+      return pos == 0 || pos == rank - 1;
+    };
+    if (!isBoundary(inputNonUnitPos, inputRank) ||
+        !isBoundary(outputNonUnitPos, outputRank))
+      return false;
+  }
+
+  // The non-unit dimension must have the same size on both sides.
+  if (inputNonUnitCount == 1 && outputNonUnitCount == 1 &&
+      inputNonUnitSize != outputNonUnitSize)
+    return false;
+
+  SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
+  SmallVector<unsigned> nonZeroIdxPositions;
+  nonZeroIdxPositions.reserve(idxs.size());
+
+  // Record non-zero indices.
+  //
+  // During rank expansion, the rewrite drops the extra unit-dimension indices.
+  // That is only semantics-preserving if every dropped index is zero.
+  for (auto [pos, idx] : llvm::enumerate(idxs)) {
+    if (!isConstZero(idx))
+      nonZeroIdxPositions.push_back(pos);
+  }
+
+  // Position of the unique non-unit dim in the output, if present:
+  //   - 0            for shapes like [N, 1, 1]
+  //   - outputRank-1 for shapes like [1, 1, N]
+  //
+  // For the all-ones case, treat it like the "non-unit on the right" case.
+  unsigned nonUnitDimPos =
+      (outputNonUnitCount == 1 && outputTy.getDimSize(0) != 1) ? 0
+                                                               : outputRank - 1;
+
+  if (outputRank >= inputRank) {
+    // Rank expansion case.
+    //
+    // The rewrite keeps only inputRank indices. Any non-zero index in an
+    // expanded unit dimension that would be discarded makes the fold invalid.
+    if (nonUnitDimPos == 0) {
+      // Expansion on the right: keep the leftmost inputRank indices.
+      // Therefore any non-zero index in the suffix would be lost.
+      for (unsigned pos : nonZeroIdxPositions) {
+        if (pos >= inputRank)
+          return false;
+      }
+    } else {
+      // Expansion on the left: keep the rightmost inputRank indices.
+      // Therefore any non-zero index in the prefix would be lost.
+      unsigned firstValidPos = outputRank - inputRank;
+      for (unsigned pos : nonZeroIdxPositions) {
+        if (pos < firstValidPos)
+          return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::LoadOp op,
+                                PatternRewriter &rewriter) const override {
+    auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+    if (!rc)
+      return failure();
+
+    // This fold is only correct for the narrow "pure rank reshape of a single
+    // logical dimension" cases accepted by isPureRankReshape().
+    if (!isPureRankReshape(rc, op))
+      return failure();
+
+    auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
+    auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
+
+    int64_t rcOutputRank = rcOutputTy.getRank();
+    int64_t rcInputRank = rcInputTy.getRank();
+
+    SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
+    SmallVector<Value> rcInputIdxs;
+
+    // The fold only supports reshapes with at most one non-unit dimension,
+    // located at the left or right boundary.
+    //
+    // The higher-rank side tells which side the reshape has expanded/collapsed.
+    //
+    //   expansion: rcOutput has the higher rank
+    //   collapse : rcInput has the higher rank
+    //
+    // Example:
+    //   memref<999>     -> memref<1x1x999>   : extra dims to the left
+    //   memref<999x1x1> -> memref<999>       : extra dims to the right
+    MemRefType expandedTy =
+        rcOutputRank >= rcInputRank ? rcOutputTy : rcInputTy;
+    bool nonUnitOnLeft = expandedTy.getDimSize(0) != 1;
+
+    if (rcOutputRank >= rcInputRank) {
+      // Rank expansion:
+      //   memref<N>   -> memref<1x1xN>   : keep the last rcInputRank indices
+      //   memref<N>   -> memref<Nx1x1>   : keep the first rcInputRank indices
+      //
+      // Any discarded indices are known to be zero from isPureRankReshape().
+      if (nonUnitOnLeft) {
+        for (int64_t dim = 0; dim < rcInputRank; ++dim)
+          rcInputIdxs.push_back(idxs[dim]);
+      } else {
+        for (int64_t dim = 0; dim < rcInputRank; ++dim)
+          rcInputIdxs.push_back(idxs[rcOutputRank - rcInputRank + dim]);
+      }
+    } else {
+      // Rank collapse:
+      //   memref<1x1xN> -> memref<N>      : reinsert leading zeros
+      //   memref<Nx1x1> -> memref<N>      : reinsert trailing zeros
+      //
+      // The collapsed-away dimensions are unit dims, so re-adding them with
+      // zero indices preserves semantics.
+      Value c0 = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+      int64_t rankDiff = rcInputRank - rcOutputRank;
+
+      if (nonUnitOnLeft) {
+        rcInputIdxs.append(idxs.begin(), idxs.end());
+        rcInputIdxs.append(rankDiff, c0);
+      } else {
+        rcInputIdxs.append(rankDiff, c0);
+        rcInputIdxs.append(idxs.begin(), idxs.end());
+      }
+    }
+
+    // Sanity check: rewritten load must index the source memref with exactly
+    // as many indices as the rank.
+    if ((int64_t)rcInputIdxs.size() != rcInputRank)
+      return failure();
+
+    auto rcInput = rc.getSource();
+    // If the only user of rc is the current Op (which is about to be erased),
+    // we can safely erase it.
+    if (rc.getResult().hasOneUse())
+      rewriter.eraseOp(rc);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, rcInput, rcInputIdxs);
+
+    // A reinterpret_cast with additional users is left in place here. Once
+    // those are rewritten it may become dead, and canonical DCE can remove it.
+    return success();
+  }
+};
+
 struct ElideReinterpretCastPass
     : public memref::impl::ElideReinterpretCastPassBase<
           ElideReinterpretCastPass> {
@@ -211,6 +443,12 @@ struct ElideReinterpretCastPass
         return true;
       return !isScalarSlice(rc);
     });
+    target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp op) {
+      auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+      if (!rc)
+        return true;
+      return !isPureRankReshape(rc, op);
+    });
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -222,5 +460,6 @@ struct ElideReinterpretCastPass
 
 void mlir::memref::populateElideReinterpretCastPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CopyToScalarLoadAndStore>(patterns.getContext());
+  patterns.add<CopyToScalarLoadAndStore, FoldReinterpretCastLoad>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
index da47562e9c0d6..5733c97ea8f3b 100644
--- a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -memref-elide-reinterpret-cast %s \
+// RUN: | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // Positive tests
@@ -220,3 +221,309 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
   : memref<1x1xf32> to memref<1x1xf32>
   return
 }
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @reshape_expand_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @reshape_expand_scalar(%src : memref<1xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]]] : memref<1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @reshape_collapse_scalar(%src : memref<1x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+      to memref<1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<999xi64> to memref<1x1x999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x999xi64>) {
+func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<1x1x999xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
+  %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @reshape_expand_left_inner_unit_dims(
+    %src : memref<1x108xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+      : memref<1x108xf32> to memref<1x1x1x108xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<1x108xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+    : memref<1x1x1x108xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @reshape_collapse_left_inner_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 100], strides: [100, 1]
+      : memref<1x1x1x100xf32> to memref<1x100xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1x100xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x100xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
+      : memref<999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<999x1x1xi64,
+    strided<[1, 999, 999]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999x1x1xi64>) {
+func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<999x1x1xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0]]] : memref<999x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @reshape_expand_right_inner_unit_dims(
+    %src : memref<108x1xf32>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+      : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<108x1xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+    : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<100x1x1x1xf32>) {
+func.func private @reshape_collapse_right_inner_unit_dims(
+    %src : memref<100x1x1x1xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [100, 1], strides: [1, 100]
+      : memref<100x1x1x1xf32> to memref<100x1xf32, strided<[1, 100]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]], %[[C0_0]]] : memref<100x1x1x1xf32>
+  %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<100x1xf32,
+    strided<[1, 100]>>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @negative_reshape_nonzero_offset(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_reshape_nonzero_offset(
+    %src : memref<1xi64>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1]
+    : memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_dynamic_shape(
+// CHECK-SAME:   %[[DIM:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<?xi64>
+func.func private @negative_reshape_dynamic_shape(%dim : index, %i : index,
+    %src : memref<?xi64>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, %[[DIM]]], strides: [1, 1] : memref<?xi64> to memref<1x?xi64>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, %dim], strides: [1, 1]
+      : memref<?xi64> to memref<1x?xi64>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %i] : memref<1x?xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_dynamic_stride(
+// CHECK-SAME:   %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xi64>
+func.func private @negative_reshape_dynamic_stride(%stride0: index,
+    %stride1: index, %src : memref<1x108xi64>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1], strides: [%[[STR0]], %[[STR1]]] : memref<1x108xi64> to memref<1x1xi64, strided<[?, ?]>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+    : memref<1x108xi64>
+      to memref<1x1xi64, strided<[?, ?]>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1]
+    : memref<1x1xi64, strided<[?, ?]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_multiple_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<2x1x1x100xf32>) {
+func.func private @negative_reshape_multiple_non_unit_dims(
+    %src : memref<2x1x1x100xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [2, 100], strides: [100, 1] : memref<2x1x1x100xf32> to memref<2x100xf32>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [2, 100], strides: [100, 1]
+      : memref<2x1x1x100xf32> to memref<2x100xf32>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<2x100xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_diff_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_reshape_diff_non_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 101], strides: [101, 1] : memref<1x1x1x100xf32> to memref<1x101xf32>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 101], strides: [101, 1]
+      : memref<1x1x1x100xf32> to memref<1x101xf32>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x101xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_inner_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_reshape_inner_non_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100] : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
+      : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
+    strided<[100, 1, 100]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_expand_left_discarded_indices(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @negative_reshape_expand_left_discarded_indices(
+    %src : memref<1x108xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1] : memref<1x108xf32> to memref<1x1x1x108xf32>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+      : memref<1x108xf32> to memref<1x1x1x108xf32>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+    : memref<1x1x1x108xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_expand_right_discarded_indices(
+// CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @negative_reshape_expand_right_discarded_indices(
+    %src : memref<108x1xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108] : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+      : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+    : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  return
+}

>From 843edbc47c0c5fa261fa82d7e0062142e96e7ed8 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Tue, 31 Mar 2026 12:09:03 +0200
Subject: [PATCH 2/4] Address first round of comments

---
 .../Transforms/ElideReinterpretCast.cpp       | 324 ++++++++++--------
 .../MemRef/elide-reinterpret-cast.mlir        | 181 ++++++----
 2 files changed, 285 insertions(+), 220 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
index 12b0b0bb7bd6a..37b2e70e43f54 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Repeated.h"
 #include <cassert>
+#include <optional>
 
 namespace mlir {
 namespace memref {
@@ -199,147 +200,164 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
 
 static bool isConstZero(Value v) { return matchPattern(v, m_Zero()); }
 
-static bool isPureRankReshape(memref::ReinterpretCastOp rc, memref::LoadOp op) {
+static std::optional<int64_t> getConstantIndex(Value v) {
+  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
+    return cst.value();
+  return std::nullopt;
+}
+
+struct SingleNonUnitDimInfo {
+  bool Exists = false;
+  bool isOnLeft = true;
+  bool isOnRight = true;
+};
+
+/// Returns information about a MemRef with at most one non-unit dimension.
+///
+/// The single non-unit dimension, if present, must be on the left or right
+/// boundary. Rank-1 non-unit memrefs are treated as being on both boundaries.
+static std::optional<SingleNonUnitDimInfo>
+getSingleNonUnitDimInfo(MemRefType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+  int64_t nonUnitCount =
+      llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
+  if (nonUnitCount == 0)
+    return SingleNonUnitDimInfo{};
+  if (nonUnitCount > 1)
+    return std::nullopt;
+
+  bool isOnLeft = shape.front() != 1;
+  bool isOnRight = shape.back() != 1;
+  if (!isOnLeft && !isOnRight)
+    return std::nullopt;
+
+  return SingleNonUnitDimInfo{/*Exists=*/true, isOnLeft, isOnRight};
+}
+
+static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
+  ArrayRef<int64_t> offsets = rc.getStaticOffsets();
+  // FIXME: Despite what `getStaticOffsets` implies, `reinterpret_cast` takes
+  // only a single offset. That should be fixed at the op definition level.
+  return offsets.size() == 1 && !ShapedType::isDynamic(offsets[0]) &&
+         offsets[0] == 0;
+}
+
+static bool isConstantIndexInBounds(Value idx, int64_t upperBound) {
+  if (isConstZero(idx))
+    return true;
+
+  std::optional<int64_t> idxVal = getConstantIndex(idx);
+  return idxVal && *idxVal >= 0 && *idxVal < upperBound;
+}
+
+/// Checks for pure rank expansion/collapsing of a single logical dimension:
+///   - all metadata is static
+///   - offset is 0
+///   - source/result each have at most one non-unit dim
+///   - if a non-unit dim exists, it is at the left or right boundary
+///
+/// Examples accepted by this shape restriction:
+///   memref<999xf32>       <-> memref<1x1x999xf32>
+///   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
+///   memref<100x1xf32>     <-> memref<100x1x1xf32>
+///   memref<1>             <-> memref<1x1x1>
+///
+/// General reinterpret_casts are intentionally rejected.
+static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
   auto inputTy = cast<MemRefType>(rc.getSource().getType());
   auto outputTy = cast<MemRefType>(rc.getResult().getType());
 
-  // This fold only handles reinterpret_casts that behave like pure rank
-  // reshapes of a single logical dimension:
-  //
-  //   - all metadata is static
-  //   - offset is 0
-  //   - source/result each have at most one non-unit dim
-  //   - if a non-unit dim exists, it is at the left or right boundary
-  //
-  // Examples accepted by this shape restriction:
-  //   memref<999xf32>       <-> memref<1x1x999xf32>
-  //   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
-  //   memref<100x1xf32>     <-> memref<100x1x1xf32>
-  //
-  // General reinterpret_casts are intentionally rejected.
-
-  auto offsets = rc.getStaticOffsets();
-  assert(offsets.size() == 1 && "Expecting single offset");
-
-  // The rewrite drops the reinterpret_cast and remaps indices directly to the
-  // source memref. That is only correct if there is no storage shift.
-  if (ShapedType::isDynamic(offsets[0]) || offsets[0] != 0)
+  // The rewrite re-uses the load indices as-is and never adjusts them for an
+  // offset. Thus, non-zero (storage shift) and statically unknown offsets
+  // are rejected.
+  if (!hasStaticZeroOffset(rc))
     return false;
 
-  auto sizes = rc.getStaticSizes();
-  auto strides = rc.getStaticStrides();
-
-  // Require fully static metadata. The fold relies on knowing exactly which
-  // dimensions are unit dimensions and which indices may be ignored.
-  if (llvm::any_of(sizes, ShapedType::isDynamic))
-    return false;
-  if (llvm::any_of(strides, ShapedType::isDynamic))
+  // The rewrite relies on fully static sizes and strides.
+  if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
+      llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
     return false;
 
-  // Count non-unit dims and remember their positions.
-  //
-  // The rewrite supports shapes with at most one non-unit dimension.
-  // This excludes underlying multi-dimensional layouts and keeps the
-  // fold limited to unit-dim insertion/removal reshapes.
-  unsigned inputRank = inputTy.getRank();
-  int inputNonUnitCount = 0;
-  int64_t inputNonUnitSize = 1;
-  unsigned inputNonUnitPos = 0;
-  for (unsigned i = 0; i < inputRank; ++i) {
-    if (inputTy.getDimSize(i) != 1) {
-      ++inputNonUnitCount;
-      inputNonUnitPos = i;
-      inputNonUnitSize = inputTy.getDimSize(i);
-    }
-  }
-
-  unsigned outputRank = outputTy.getRank();
-  int outputNonUnitCount = 0;
-  int64_t outputNonUnitSize = 1;
-  unsigned outputNonUnitPos = 0;
-  for (unsigned i = 0; i < outputRank; ++i) {
-    if (outputTy.getDimSize(i) != 1) {
-      ++outputNonUnitCount;
-      outputNonUnitPos = i;
-      outputNonUnitSize = outputTy.getDimSize(i);
-    }
-  }
+  // The rewrite only supports shapes with at most one non-unit dimension.
+  // This excludes underlying multi-dimensional layouts and keeps the
+  // rewrite limited to unit-dim insertion/removal `reinterpret_cast`s.
+  std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
+      getSingleNonUnitDimInfo(inputTy);
+  std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
+      getSingleNonUnitDimInfo(outputTy);
+  if (!inputNonUnitDim || !outputNonUnitDim)
+    return false;
 
-  // Reject reshapes with > 1 non-unit-dimension.
-  //
-  // The source and result must have the same number of non-unit dimensions:
-  // either both are all-ones, or both have exactly one non-unit dimension.
-  if (inputNonUnitCount > 1 || outputNonUnitCount > 1 ||
-      inputNonUnitCount != outputNonUnitCount)
+  // The source and result must either both be all-ones, or both have a single
+  // non-unit dimension.
+  if (inputNonUnitDim->Exists != outputNonUnitDim->Exists)
     return false;
+  if (!inputNonUnitDim->Exists)
+    return true;
 
-  // If there is a non-unit dimension, it must live at the same boundary
-  // (first or last dimension) on both input and output memrefs.
-  // The rewrite logic for preserving the load index is exclusive to these
-  // cases.
-  if (inputNonUnitCount == 1) {
-    auto isBoundary = [](unsigned pos, unsigned rank) {
-      return pos == 0 || pos == rank - 1;
-    };
-    if (!isBoundary(inputNonUnitPos, inputRank) ||
-        !isBoundary(outputNonUnitPos, outputRank))
-      return false;
-  }
+  // If there is a non-unit dimension, it must be preserved on the same
+  // boundary. Rank-1 memrefs are accepted against either boundary.
+  return (inputNonUnitDim->isOnLeft && outputNonUnitDim->isOnLeft) ||
+         (inputNonUnitDim->isOnRight && outputNonUnitDim->isOnRight);
+}
 
-  // Size of non-unit dimension must be the same
-  if (inputNonUnitCount == 1 && outputNonUnitCount == 1 &&
-      inputNonUnitSize != outputNonUnitSize)
+/// Checks that every index of `load` is a constant within the bounds of the
+/// reinterpret_cast result, and that every index the rewrite drops is 0.
+/// Returns false for out-of-bounds indices or non-constant indices.
+static bool areIndicesForUnitDimsInBounds(memref::LoadOp load) {
+  auto rc = load.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+  auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
+  auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
+  int64_t rcInputRank = rcInputTy.getRank();
+  int64_t rcOutputRank = rcOutputTy.getRank();
+
+  std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
+      getSingleNonUnitDimInfo(rcInputTy);
+  std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
+      getSingleNonUnitDimInfo(rcOutputTy);
+  if (!inputNonUnitDim || !outputNonUnitDim)
     return false;
 
-  SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
   SmallVector<unsigned> nonZeroIdxPositions;
-  nonZeroIdxPositions.reserve(idxs.size());
+  nonZeroIdxPositions.reserve(load.getIndices().size());
 
-  // Record non-zero indices.
-  //
-  // During rank expansion, the rewrite drops the extra unit-dimension indices.
-  // That is only semantics-preserving if every dropped index is zero.
-  for (auto [pos, idx] : llvm::enumerate(idxs)) {
-    if (!isConstZero(idx))
-      nonZeroIdxPositions.push_back(pos);
-  }
+  // Check if non-zero indices are out-of-bounds.
+  // Only care about indices corresponding to the load from the
+  // reinterpret_cast result.
+  for (auto [pos, idx] : llvm::enumerate(load.getIndices())) {
+    if (isConstZero(idx))
+      continue;
 
-  // Position of the unique non-unit dim in the output, if present:
-  //   - 0            for shapes like [N, 1, 1]
-  //   - outputRank-1 for shapes like [1, 1, N]
-  //
-  // For the all-ones case, treat it like the "non-unit on the right" case.
-  unsigned nonUnitDimPos =
-      (outputNonUnitCount == 1 && outputTy.getDimSize(0) != 1) ? 0
-                                                               : outputRank - 1;
-
-  if (outputRank >= inputRank) {
-    // Rank expansion case.
-    //
-    // The rewrite keeps only inputRank indices. Any non-zero index in an
-    // expanded unit dimension that would be discarded makes the fold invalid.
-    if (nonUnitDimPos == 0) {
-      // Expansion on the right: keep the leftmost inputRank indices.
-      // Therefore any non-zero index in the suffix would be lost.
-      for (unsigned pos : nonZeroIdxPositions) {
-        if (pos >= inputRank)
-          return false;
-      }
-    } else {
-      // Expansion on the left: keep the rightmost inputRank indices.
-      // Therefore any non-zero index in the prefix would be lost.
-      unsigned firstValidPos = outputRank - inputRank;
-      for (unsigned pos : nonZeroIdxPositions) {
-        if (pos < firstValidPos)
-          return false;
-      }
-    }
+    // Bail out early for the all-ones case.
+    if (!inputNonUnitDim->Exists)
+      return false;
+
+    // FIXME: This should be ensured by the memref.load semantics.
+    if (!isConstantIndexInBounds(idx, rcOutputTy.getDimSize(pos)))
+      return false;
+
+    nonZeroIdxPositions.push_back(pos);
   }
 
-  return true;
+  // During rank expansion, the rewrite drops the extra unit-dimension indices.
+  // That is only semantics-preserving if every dropped index is zero.
+  // This check is only relevant for expansions with a non-unit dimension.
+  if (rcOutputRank < rcInputRank || !inputNonUnitDim->Exists)
+    return true;
+
+  // The rewrite keeps either a prefix or suffix of length `rcInputRank`.
+  // Any non-zero index outside the preserved slice would be discarded.
+  if (outputNonUnitDim->isOnLeft)
+    return llvm::none_of(nonZeroIdxPositions,
+                         [&](unsigned pos) { return pos >= rcInputRank; });
+
+  unsigned firstKeptPos = rcOutputRank - rcInputRank;
+  return llvm::none_of(nonZeroIdxPositions,
+                       [&](unsigned pos) { return pos < firstKeptPos; });
 }
 
-struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
+struct RewriteLoadFromReinterpretCast
+    : public OpRewritePattern<memref::LoadOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
@@ -349,9 +367,11 @@ struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
     if (!rc)
       return failure();
 
-    // This fold is only correct for the narrow "pure rank reshape of a single
-    // logical dimension" cases accepted by isPureRankReshape().
-    if (!isPureRankReshape(rc, op))
+    // This rewrite is only correct for the narrow "pure rank expansion or
+    // collapsing of a single logical dimension" cases accepted by these two
+    // checks.
+    if (!isPureRankExpansionOrCollapsingRC(rc) ||
+        !areIndicesForUnitDimsInBounds(op))
       return failure();
 
     auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
@@ -362,46 +382,51 @@ struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
 
     SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
     SmallVector<Value> rcInputIdxs;
+    rcInputIdxs.reserve(rcInputRank);
 
-    // The fold only supports reshapes with at most one non-unit dimension,
-    // located at the left or right boundary.
+    // The rewrite only supports reinterpret_casts with at most one non-unit
+    // dimension, located at the left or right boundary.
     //
-    // The higher-rank side tells which side the reshape has expanded/collapsed.
+    // The higher-rank side tells which side the reinterpret_cast has
+    // expanded/collapsed.
     //
     //   expansion: rcOutput has the higher rank
-    //   collapse : rcInput has the higher rank
+    //   collapsing : rcInput has the higher rank
     //
     // Example:
     //   memref<999>     -> memref<1x1x999>   : extra dims to the left
     //   memref<999x1x1> -> memref<999>       : extra dims to the right
     MemRefType expandedTy =
         rcOutputRank >= rcInputRank ? rcOutputTy : rcInputTy;
-    bool nonUnitOnLeft = expandedTy.getDimSize(0) != 1;
+    std::optional<SingleNonUnitDimInfo> expandedNonUnitDim =
+        getSingleNonUnitDimInfo(expandedTy);
+    assert(expandedNonUnitDim && "expected a single boundary non-unit dim");
+    bool keepLeadingIndices = expandedNonUnitDim->isOnLeft;
 
     if (rcOutputRank >= rcInputRank) {
       // Rank expansion:
-      //   memref<N>   -> memref<1x1xN>   : keep the last rcInputRank indices
-      //   memref<N>   -> memref<Nx1x1>   : keep the first rcInputRank indices
+      //   memref<N>     -> memref<1x1xN> : keep the last rcInputRank indices
+      //   memref<N>     -> memref<Nx1x1> : keep the first rcInputRank indices
+      //   memref<1>     -> memref<1x1x1> : all indices are zero
       //
-      // Any discarded indices are known to be zero from isPureRankReshape().
-      if (nonUnitOnLeft) {
-        for (int64_t dim = 0; dim < rcInputRank; ++dim)
-          rcInputIdxs.push_back(idxs[dim]);
-      } else {
-        for (int64_t dim = 0; dim < rcInputRank; ++dim)
-          rcInputIdxs.push_back(idxs[rcOutputRank - rcInputRank + dim]);
-      }
+      // Any discarded indices are known to be zero from
+      // areIndicesForUnitDimsInBounds().
+      int64_t firstKeptPos =
+          keepLeadingIndices ? 0 : rcOutputRank - rcInputRank;
+      rcInputIdxs.append(idxs.begin() + firstKeptPos,
+                         idxs.begin() + firstKeptPos + rcInputRank);
     } else {
-      // Rank collapse:
-      //   memref<1x1xN> -> memref<N>      : reinsert leading zeros
-      //   memref<Nx1x1> -> memref<N>      : reinsert trailing zeros
+      // Rank collapsing:
+      //   memref<1x1xN> -> memref<N>     : reinsert leading zeros
+      //   memref<Nx1x1> -> memref<N>     : reinsert trailing zeros
+      //   memref<1x1x1> -> memref<1>     : all indices are zero
       //
-      // The collapsed-away dimensions are unit dims, so readding them with
+      // The collapsed-away dimensions are unit dims, so re-adding them with
       // zero indices preserves semantics.
       Value c0 = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
       int64_t rankDiff = rcInputRank - rcOutputRank;
 
-      if (nonUnitOnLeft) {
+      if (keepLeadingIndices) {
         rcInputIdxs.append(idxs.begin(), idxs.end());
         rcInputIdxs.append(rankDiff, c0);
       } else {
@@ -410,10 +435,8 @@ struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
       }
     }
 
-    // Sanity check: rewritten load must index the source memref with exactly
-    // as many indices as the rank.
-    if ((int64_t)rcInputIdxs.size() != rcInputRank)
-      return failure();
+    assert(rcInputIdxs.size() == static_cast<size_t>(rcInputRank) &&
+           "Incorrect number of indices!");
 
     auto rcInput = rc.getSource();
     // If the only user of rc is the current Op (which is about to be erased),
@@ -447,7 +470,8 @@ struct ElideReinterpretCastPass
       auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
       if (!rc)
         return true;
-      return !isPureRankReshape(rc, op);
+      return !(isPureRankExpansionOrCollapsingRC(rc) &&
+               areIndicesForUnitDimsInBounds(op));
     });
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
@@ -460,6 +484,6 @@ struct ElideReinterpretCastPass
 
 void mlir::memref::populateElideReinterpretCastPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CopyToScalarLoadAndStore, FoldReinterpretCastLoad>(
+  patterns.add<CopyToScalarLoadAndStore, RewriteLoadFromReinterpretCast>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
index 5733c97ea8f3b..944751dbb005f 100644
--- a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -179,8 +179,8 @@ func.func private @negative_concat_strided_base(%src: memref<1x1xf32>,
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_rank_change(
-func.func private @negative_reshape_rank_change(%src : memref<2x3xf32>,
+// CHECK-LABEL: func.func private @negative_rank_change(
+func.func private @negative_rank_change(%src : memref<2x3xf32>,
   %dst : memref<6xf32>) {
   // CHECK:      %reinterpret_cast = memref.reinterpret_cast %arg1
   %reinterpret_cast = memref.reinterpret_cast %dst
@@ -229,42 +229,38 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
 // Positive tests
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: func.func private @reshape_expand_scalar(
+// CHECK-LABEL: func.func private @expand_scalar(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
-func.func private @reshape_expand_scalar(%src : memref<1xi64>) {
+func.func private @expand_scalar(%src : memref<1xi64>) {
   // CHECK:       %[[C0:.*]] = arith.constant 0 : index
-  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
       to memref<1x1x1xi64>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]]] : memref<1xi64>
-  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xi64>
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_collapse_scalar(
+// CHECK-LABEL: func.func private @collapse_scalar(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1xi64>) {
-func.func private @reshape_collapse_scalar(%src : memref<1x1x1xi64>) {
+func.func private @collapse_scalar(%src : memref<1x1x1xi64>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
-  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
       to memref<1x1xi64>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1xi64>
-  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0] : memref<1x1xi64>
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_expand_left_vector(
+// CHECK-LABEL: func.func private @expand_left_vector(
 // CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
-func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
+func.func private @expand_left_vector(%src : memref<999xi64>) {
   // CHECK:       %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   // CHECK-NOT:   memref.reinterpret_cast
@@ -276,9 +272,9 @@ func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_collapse_left_vector(
+// CHECK-LABEL: func.func private @collapse_left_vector(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1x1x999xi64>) {
-func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
+func.func private @collapse_left_vector(%src : memref<1x1x999xi64>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
@@ -291,9 +287,9 @@ func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_expand_left_inner_unit_dims(
+// CHECK-LABEL: func.func private @expand_left_inner_unit_dims(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
-func.func private @reshape_expand_left_inner_unit_dims(
+func.func private @expand_left_inner_unit_dims(
     %src : memref<1x108xf32>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
@@ -303,15 +299,15 @@ func.func private @reshape_expand_left_inner_unit_dims(
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
       : memref<1x108xf32> to memref<1x1x1x108xf32>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<1x108xf32>
-  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0, %c1]
     : memref<1x1x1x108xf32>
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_collapse_left_inner_unit_dims(
+// CHECK-LABEL: func.func private @collapse_left_inner_unit_dims(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @reshape_collapse_left_inner_unit_dims(
+func.func private @collapse_left_inner_unit_dims(
     %src : memref<1x1x1x100xf32>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
@@ -327,9 +323,9 @@ func.func private @reshape_collapse_left_inner_unit_dims(
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_expand_right_vector(
+// CHECK-LABEL: func.func private @expand_right_vector(
 // CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
-func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
+func.func private @expand_right_vector(%src : memref<999xi64>) {
   // CHECK:       %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   // CHECK-NOT:   memref.reinterpret_cast
@@ -342,9 +338,9 @@ func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_collapse_right_vector(
+// CHECK-LABEL: func.func private @collapse_right_vector(
 // CHECK-SAME:    %[[SRC:.*]]: memref<999x1x1xi64>) {
-func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
+func.func private @collapse_right_vector(%src : memref<999x1x1xi64>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
@@ -357,9 +353,9 @@ func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_expand_right_inner_unit_dims(
+// CHECK-LABEL: func.func private @expand_right_inner_unit_dims(
 // CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
-func.func private @reshape_expand_right_inner_unit_dims(
+func.func private @expand_right_inner_unit_dims(
     %src : memref<108x1xf32>) {
   // CHECK:       %[[C0:.*]] = arith.constant 0 : index
   // CHECK:       %[[C1:.*]] = arith.constant 1 : index
@@ -369,15 +365,15 @@ func.func private @reshape_expand_right_inner_unit_dims(
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
       : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<108x1xf32>
-  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<108x1xf32>
+  %0 = memref.load %reinterpret_cast[%c1, %c0, %c0, %c0]
     : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
   return
 }
 
-// CHECK-LABEL: func.func private @reshape_collapse_right_inner_unit_dims(
+// CHECK-LABEL: func.func private @collapse_right_inner_unit_dims(
 // CHECK-SAME:    %[[SRC:.*]]: memref<100x1x1x1xf32>) {
-func.func private @reshape_collapse_right_inner_unit_dims(
+func.func private @collapse_right_inner_unit_dims(
     %src : memref<100x1x1x1xf32>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
@@ -394,17 +390,35 @@ func.func private @reshape_collapse_right_inner_unit_dims(
   return
 }
 
+// CHECK-LABEL: func.func private @negative_diff_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_diff_non_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C98:.*]] = arith.constant 98 : index
+  %c0 = arith.constant 0 : index
+  %c98 = arith.constant 98 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 99], strides: [99, 1]
+      : memref<1x1x1x100xf32> to memref<1x99xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C98]]] : memref<1x1x1x100xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c98] : memref<1x99xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Negative tests (must NOT rewrite)
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: func.func private @negative_reshape_nonzero_offset(
+// CHECK-LABEL: func.func private @negative_nonzero_offset(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
-func.func private @negative_reshape_nonzero_offset(
+func.func private @negative_nonzero_offset(
     %src : memref<1xi64>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
       to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
@@ -414,13 +428,13 @@ func.func private @negative_reshape_nonzero_offset(
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_dynamic_shape(
+// CHECK-LABEL: func.func private @negative_dynamic_shape(
 // CHECK-SAME:   %[[DIM:[A-Za-z][A-Za-z0-9-]*]]: index
 // CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<?xi64>
-func.func private @negative_reshape_dynamic_shape(%dim : index, %i : index,
+func.func private @negative_dynamic_shape(%dim : index, %i : index,
     %src : memref<?xi64>) {
   %c0 = arith.constant 0 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, %[[DIM]]], strides: [1, 1] : memref<?xi64> to memref<1x?xi64>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, %dim], strides: [1, 1]
       : memref<?xi64> to memref<1x?xi64>
@@ -429,15 +443,15 @@ func.func private @negative_reshape_dynamic_shape(%dim : index, %i : index,
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_dynamic_stride(
+// CHECK-LABEL: func.func private @negative_dynamic_stride(
 // CHECK-SAME:   %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
 // CHECK-SAME:   %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
 // CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xi64>
-func.func private @negative_reshape_dynamic_stride(%stride0: index,
+func.func private @negative_dynamic_stride(%stride0: index,
     %stride1: index, %src : memref<1x108xi64>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1], strides: [%[[STR0]], %[[STR1]]] : memref<1x108xi64> to memref<1x1xi64, strided<[?, ?]>>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
     : memref<1x108xi64>
@@ -448,13 +462,13 @@ func.func private @negative_reshape_dynamic_stride(%stride0: index,
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_multiple_non_unit_dims(
+// CHECK-LABEL: func.func private @negative_multiple_non_unit_dims(
 // CHECK-SAME:    %[[SRC:.*]]: memref<2x1x1x100xf32>) {
-func.func private @negative_reshape_multiple_non_unit_dims(
-    %src : memref<2x1x1x100xf32>) {
+func.func private @negative_multiple_non_unit_dims(
+  %src : memref<2x1x1x100xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [2, 100], strides: [100, 1] : memref<2x1x1x100xf32> to memref<2x100xf32>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [2, 100], strides: [100, 1]
       : memref<2x1x1x100xf32> to memref<2x100xf32>
@@ -463,46 +477,73 @@ func.func private @negative_reshape_multiple_non_unit_dims(
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_diff_non_unit_dims(
+// CHECK-LABEL: func.func private @negative_inner_non_unit_dims(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @negative_reshape_diff_non_unit_dims(
+func.func private @negative_inner_non_unit_dims(
     %src : memref<1x1x1x100xf32>) {
-  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 101], strides: [101, 1] : memref<1x1x1x100xf32> to memref<1x101xf32>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 101], strides: [101, 1]
-      : memref<1x1x1x100xf32> to memref<1x101xf32>
+    to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
+      : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
   // CHECK:       memref.load %[[RC]]
-  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x101xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
+    strided<[100, 1, 100]>>
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_inner_non_unit_dims(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @negative_reshape_inner_non_unit_dims(
-    %src : memref<1x1x1x100xf32>) {
+// CHECK-LABEL: func.func private @negative_expand_out_of_bounds(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_expand_out_of_bounds(%src : memref<1xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100] : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
-      : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64>
   // CHECK:       memref.load %[[RC]]
-  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
-    strided<[100, 1, 100]>>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_collapse_out_of_bounds(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @negative_collapse_out_of_bounds(%src : memref<1x1x1xi64>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+      to memref<1x1xi64>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_expand_negative_index(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_expand_negative_index(%src : memref<1xi64>) {
+  %c0 = arith.constant 0 : index
+  %cneg1 = arith.constant -1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %cneg1] : memref<1x1x1xi64>
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_expand_left_discarded_indices(
+// CHECK-LABEL: func.func private @negative_expand_left_discarded_indices(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
-func.func private @negative_reshape_expand_left_discarded_indices(
+func.func private @negative_expand_left_discarded_indices(
     %src : memref<1x108xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1] : memref<1x108xf32> to memref<1x1x1x108xf32>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
       : memref<1x108xf32> to memref<1x1x1x108xf32>
@@ -512,13 +553,13 @@ func.func private @negative_reshape_expand_left_discarded_indices(
   return
 }
 
-// CHECK-LABEL: func.func private @negative_reshape_expand_right_discarded_indices(
+// CHECK-LABEL: func.func private @negative_expand_right_discarded_indices(
 // CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
-func.func private @negative_reshape_expand_right_discarded_indices(
+func.func private @negative_expand_right_discarded_indices(
     %src : memref<108x1xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108] : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
       : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>

>From 746d3e376844a7536cf099d3fe2a6fea432533b0 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Tue, 21 Apr 2026 12:21:12 +0200
Subject: [PATCH 3/4] Address second round of comments

---
 .../Transforms/ElideReinterpretCast.cpp       | 164 ++++++--------
 .../MemRef/elide-reinterpret-cast.mlir        | 213 +++++++-----------
 2 files changed, 151 insertions(+), 226 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
index 37b2e70e43f54..06d2f48b0e551 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -198,64 +198,67 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
-static bool isConstZero(Value v) { return matchPattern(v, m_Zero()); }
-
-static std::optional<int64_t> getConstantIndex(Value v) {
-  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
-    return cst.value();
-  return std::nullopt;
-}
-
+/// Describes the unique non-unit dimension of a MemRef shape.
+///
+/// This helper is only used for shapes that have at most one non-unit
+/// dimension. `exists` is false for all-ones shapes. Otherwise, `isOnLeft`
+/// indicates whether the non-unit dimension is on the left boundary.
+///
+/// If `exists` is true and `isOnLeft` is false, the non-unit dimension is on
+/// the right boundary. Rank-1 non-unit MemRefs are treated as matching both
+/// boundaries and callers that care about the right boundary must account for
+/// that from the MemRef type.
 struct SingleNonUnitDimInfo {
-  bool Exists = false;
-  bool isOnLeft = true;
-  bool isOnRight = true;
+  bool exists = false;
+  bool isOnLeft = false;
 };
 
-/// Returns information about a MemRef with at most one non-unit dimension.
+/// Returns information about a MemRef if it contains at most one non-unit
+/// dimension.
 ///
 /// The single non-unit dimension, if present, must be on the left or right
-/// boundary. Rank-1 non-unit memrefs are treated as being on both boundaries.
+/// boundary. Rank-1 non-unit MemRefs are treated as being on both boundaries.
 static std::optional<SingleNonUnitDimInfo>
 getSingleNonUnitDimInfo(MemRefType type) {
   ArrayRef<int64_t> shape = type.getShape();
   int64_t nonUnitCount =
       llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
+  // No non-unit dimension: return the default info (`exists` is false).
   if (nonUnitCount == 0)
     return SingleNonUnitDimInfo{};
+  // More than one non-unit dimension: the shape is not supported.
   if (nonUnitCount > 1)
     return std::nullopt;
 
   bool isOnLeft = shape.front() != 1;
-  bool isOnRight = shape.back() != 1;
-  if (!isOnLeft && !isOnRight)
+  // The single non-unit dimension is on neither boundary, so the shape is
+  // not supported.
+  if (!isOnLeft && shape.back() == 1)
     return std::nullopt;
 
-  return SingleNonUnitDimInfo{/*Exists=*/true, isOnLeft, isOnRight};
+  return SingleNonUnitDimInfo{/*exists=*/true, isOnLeft};
 }
 
 static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
   ArrayRef<int64_t> offsets = rc.getStaticOffsets();
   // FIXME: Despite what `getStaticOffsets` implies, `reinterpret_cast` takes
   // only a single offset. That should be fixed at the op definition level.
-  return offsets.size() == 1 && !ShapedType::isDynamic(offsets[0]) &&
-         offsets[0] == 0;
+  assert(offsets.size() == 1 && "Expecting single offset");
+  return !ShapedType::isDynamic(offsets[0]) && offsets[0] == 0;
 }
 
-static bool isConstantIndexInBounds(Value idx, int64_t upperBound) {
-  if (isConstZero(idx))
-    return true;
+static std::optional<int64_t> getConstantIndex(Value v) {
+  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
+    return cst.value();
+  return std::nullopt;
+}
 
+static bool isConstantIndexExplicitlyOutOfBounds(Value idx,
+                                                 int64_t upperBound) {
   std::optional<int64_t> idxVal = getConstantIndex(idx);
-  return idxVal && *idxVal >= 0 && *idxVal < upperBound;
+  return idxVal && (*idxVal < 0 || *idxVal >= upperBound);
 }
 
-/// Checks for pure rank expansion/collapsing of a single logical dimension:
-///   - all metadata is static
-///   - offset is 0
-///   - source/result each have at most one non-unit dim
-///   - if a non-unit dim exists, it is at the left or right boundary
-///
 /// Examples accepted by this shape restriction:
 ///   memref<999xf32>       <-> memref<1x1x999xf32>
 ///   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
@@ -267,9 +270,9 @@ static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
   auto inputTy = cast<MemRefType>(rc.getSource().getType());
   auto outputTy = cast<MemRefType>(rc.getResult().getType());
 
-  // This check assumes the rewrite relies on "index re-use" and misses "index
-  // re-write/adjustment". Thus, storage shift and statically unknown offsets
-  // are rejected.
+  // This rewrite re-uses the load indices and has no index adjustment logic,
+  // so the offset must be statically known to be 0. Non-zero and dynamic
+  // offsets are rejected.
   if (!hasStaticZeroOffset(rc))
     return false;
 
@@ -285,75 +288,46 @@ static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
       getSingleNonUnitDimInfo(inputTy);
   std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
       getSingleNonUnitDimInfo(outputTy);
+  // Bail out early if nonUnitDims don't follow rewrite assumptions.
   if (!inputNonUnitDim || !outputNonUnitDim)
     return false;
 
-  // The source and result must either both be all-ones, or both have a single
-  // non-unit dimension.
-  if (inputNonUnitDim->Exists != outputNonUnitDim->Exists)
+  // The source and result must either both have a single non-unit dimension
+  // or both be all-ones.
+  if (inputNonUnitDim->exists != outputNonUnitDim->exists)
     return false;
-  if (!inputNonUnitDim->Exists)
+  if (!inputNonUnitDim->exists)
     return true;
 
-  // If there is a non-unit dimension, it must be preserved on the same
-  // boundary. Rank-1 memrefs are accepted against either boundary.
-  return (inputNonUnitDim->isOnLeft && outputNonUnitDim->isOnLeft) ||
-         (inputNonUnitDim->isOnRight && outputNonUnitDim->isOnRight);
+  // The preserved non-unit dimension must have the same size.
+  if (inputTy.getDimSize(inputNonUnitDim->isOnLeft ? 0
+                                                   : inputTy.getRank() - 1) !=
+      outputTy.getDimSize(outputNonUnitDim->isOnLeft ? 0
+                                                     : outputTy.getRank() - 1))
+    return false;
+
+  // If both sides have rank > 1, the non-unit dimension must be on the same
+  // boundary. Rank-1 MemRefs are accepted against either boundary.
+  if (inputTy.getRank() != 1 && outputTy.getRank() != 1 &&
+      inputNonUnitDim->isOnLeft != outputNonUnitDim->isOnLeft)
+    return false;
+
+  return true;
 }
 
-/// Checks whether load indices corresponding to unit dims in the source
-/// MemRef are all 0, i.e. in-bounds.
-/// Returns false for out-of-bounds indices or non-constant indices.
-static bool areIndicesForUnitDimsInBounds(memref::LoadOp load) {
+/// Returns false if any constant index of a load from a pure rank
+/// expansion/collapsing is out of bounds for the reinterpret_cast result.
+/// Non-constant (dynamic) indices are accepted.
+static bool areIndicesInBounds(memref::LoadOp load) {
   auto rc = load.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
-  auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
   auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
-  int64_t rcInputRank = rcInputTy.getRank();
-  int64_t rcOutputRank = rcOutputTy.getRank();
 
-  std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
-      getSingleNonUnitDimInfo(rcInputTy);
-  std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
-      getSingleNonUnitDimInfo(rcOutputTy);
-  if (!inputNonUnitDim || !outputNonUnitDim)
-    return false;
-
-  SmallVector<unsigned> nonZeroIdxPositions;
-  nonZeroIdxPositions.reserve(load.getIndices().size());
-
-  // Check if non-zero indices are out-of-bounds.
-  // Only care about indices corresponding to the load from the
-  // reinterpret_cast result.
   for (auto [pos, idx] : llvm::enumerate(load.getIndices())) {
-    if (isConstZero(idx))
-      continue;
-
-    // Bail out early for the all-ones case.
-    if (!inputNonUnitDim->Exists)
-      return false;
-
     // FIXME: This should be ensured by the memref.load semantics.
-    if (!isConstantIndexInBounds(idx, rcOutputTy.getDimSize(pos)))
+    if (isConstantIndexExplicitlyOutOfBounds(idx, rcOutputTy.getDimSize(pos)))
       return false;
-
-    nonZeroIdxPositions.push_back(pos);
   }
-
-  // During rank expansion, the rewrite drops the extra unit-dimension indices.
-  // That is only semantics-preserving if every dropped index is zero.
-  // This check is only relevant for expansions with a non-unit dimension.
-  if (rcOutputRank < rcInputRank || !inputNonUnitDim->Exists)
-    return true;
-
-  // The rewrite keeps either a prefix or suffix of length `rcInputRank`.
-  // Any non-zero index outside the preserved slice would be discarded.
-  if (outputNonUnitDim->isOnLeft)
-    return llvm::none_of(nonZeroIdxPositions,
-                         [&](unsigned pos) { return pos >= rcInputRank; });
-
-  unsigned firstKeptPos = rcOutputRank - rcInputRank;
-  return llvm::none_of(nonZeroIdxPositions,
-                       [&](unsigned pos) { return pos < firstKeptPos; });
+  return true;
 }
 
 struct RewriteLoadFromReinterpretCast
@@ -365,14 +339,15 @@ struct RewriteLoadFromReinterpretCast
                                 PatternRewriter &rewriter) const override {
     auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
     if (!rc)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "target is not a memref.reinterpret_cast");
+    if (!isPureRankExpansionOrCollapsingRC(rc))
+      return rewriter.notifyMatchFailure(
+          op, "reinterpret_cast is not a pure rank expansion or collapsing of "
+              "a single dimension");
 
-    // This rewrite is only correct for the narrow "pure rank expansion or
-    // collapsing of a single logical dimension" cases accepted by these two
-    // checks.
-    if (!isPureRankExpansionOrCollapsingRC(rc) ||
-        !areIndicesForUnitDimsInBounds(op))
-      return failure();
+    assert(areIndicesInBounds(op) &&
+           "load from reinterpret_cast indexes out of bounds!");
 
     auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
     auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
@@ -410,7 +385,7 @@ struct RewriteLoadFromReinterpretCast
       //   memref<1>     -> memref<1x1x1> : all indices are zero
       //
       // Any discarded indices are known to be zero from
-      // areIndicesForUnitDimsInBounds().
+      // in-bounds access to unit dims (see areIndicesInBounds()).
       int64_t firstKeptPos =
           keepLeadingIndices ? 0 : rcOutputRank - rcInputRank;
       rcInputIdxs.append(idxs.begin() + firstKeptPos,
@@ -470,8 +445,7 @@ struct ElideReinterpretCastPass
       auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
       if (!rc)
         return true;
-      return !(isPureRankExpansionOrCollapsingRC(rc) &&
-               areIndicesForUnitDimsInBounds(op));
+      return !isPureRankExpansionOrCollapsingRC(rc);
     });
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
index 944751dbb005f..9b563abadb3aa 100644
--- a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -229,6 +229,8 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
 // Positive tests
 //===----------------------------------------------------------------------===//
 
+/// For rank-1 MemRefs, expansion/collapsing may be considered on either side.
+
 // CHECK-LABEL: func.func private @expand_scalar(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
 func.func private @expand_scalar(%src : memref<1xi64>) {
@@ -253,7 +255,7 @@ func.func private @collapse_scalar(%src : memref<1x1x1xi64>) {
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
       to memref<1x1xi64>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C0]]] : memref<1x1x1xi64>
   %0 = memref.load %reinterpret_cast[%c0, %c0] : memref<1x1xi64>
   return
 }
@@ -272,6 +274,22 @@ func.func private @expand_left_vector(%src : memref<999xi64>) {
   return
 }
 
+// CHECK-LABEL: func.func private @expand_left_vector_dynamic_index(
+// CHECK-SAME:    %[[I:.*]]: index
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @expand_left_vector_dynamic_index(%i : index,
+    %src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<999xi64> to memref<1x1x999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[I]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %i] : memref<1x1x999xi64>
+  return
+}
+
 // CHECK-LABEL: func.func private @collapse_left_vector(
 // CHECK-SAME:    %[[SRC:.*]]: memref<1x1x999xi64>) {
 func.func private @collapse_left_vector(%src : memref<1x1x999xi64>) {
@@ -287,28 +305,28 @@ func.func private @collapse_left_vector(%src : memref<1x1x999xi64>) {
   return
 }
 
-// CHECK-LABEL: func.func private @expand_left_inner_unit_dims(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
-func.func private @expand_left_inner_unit_dims(
-    %src : memref<1x108xf32>) {
+// CHECK-LABEL: func.func private @partial_expand_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x999xf32>) {
+func.func private @partial_expand_left_vector(
+    %src : memref<1x999xf32>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
-      : memref<1x108xf32> to memref<1x1x1x108xf32>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<1x108xf32>
-  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0, %c1]
-    : memref<1x1x1x108xf32>
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<1x999xf32> to memref<1x1x999xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<1x999xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1]
+    : memref<1x1x999xf32>
   return
 }
 
-// CHECK-LABEL: func.func private @collapse_left_inner_unit_dims(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @collapse_left_inner_unit_dims(
-    %src : memref<1x1x1x100xf32>) {
+// CHECK-LABEL: func.func private @partial_collapse_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x999xf32>) {
+func.func private @partial_collapse_left_vector(
+    %src : memref<1x1x999xf32>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
@@ -316,10 +334,10 @@ func.func private @collapse_left_inner_unit_dims(
   %c1 = arith.constant 1 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 100], strides: [100, 1]
-      : memref<1x1x1x100xf32> to memref<1x100xf32>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1x100xf32>
-  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x100xf32>
+    to offset: [0], sizes: [1, 999], strides: [999, 1]
+      : memref<1x1x999xf32> to memref<1x999xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x999xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x999xf32>
   return
 }
 
@@ -353,58 +371,55 @@ func.func private @collapse_right_vector(%src : memref<999x1x1xi64>) {
   return
 }
 
-// CHECK-LABEL: func.func private @expand_right_inner_unit_dims(
-// CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
-func.func private @expand_right_inner_unit_dims(
-    %src : memref<108x1xf32>) {
-  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
-  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
+// CHECK-LABEL: func.func private @collapse_right_vector_dynamic_index(
+// CHECK-SAME:    %[[I:.*]]: index
+// CHECK-SAME:    %[[SRC:.*]]: memref<999x1x1xi64>) {
+func.func private @collapse_right_vector_dynamic_index(%i : index,
+    %src : memref<999x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
-      : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<108x1xf32>
-  %0 = memref.load %reinterpret_cast[%c1, %c0, %c0, %c0]
-    : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<999x1x1xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[I]], %[[C0]], %[[C0]]] : memref<999x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%i] : memref<999xi64>
   return
 }
 
-// CHECK-LABEL: func.func private @collapse_right_inner_unit_dims(
-// CHECK-SAME:    %[[SRC:.*]]: memref<100x1x1x1xf32>) {
-func.func private @collapse_right_inner_unit_dims(
-    %src : memref<100x1x1x1xf32>) {
-  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
-  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-LABEL: func.func private @partial_expand_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999x1xf32>) {
+func.func private @partial_expand_right_vector(
+    %src : memref<999x1xf32>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [100, 1], strides: [1, 100]
-      : memref<100x1x1x1xf32> to memref<100x1xf32, strided<[1, 100]>>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]], %[[C0_0]]] : memref<100x1x1x1xf32>
-  %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<100x1xf32,
-    strided<[1, 100]>>
+    to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
+      : memref<999x1xf32> to memref<999x1x1xf32, strided<[1, 999, 999]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<999x1xf32>
+  %0 = memref.load %reinterpret_cast[%c1, %c0, %c0]
+    : memref<999x1x1xf32, strided<[1, 999, 999]>>
   return
 }
 
-// CHECK-LABEL: func.func private @negative_diff_non_unit_dims(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @negative_diff_non_unit_dims(
-    %src : memref<1x1x1x100xf32>) {
+// CHECK-LABEL: func.func private @partial_collapse_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999x1x1xf32>) {
+func.func private @partial_collapse_right_vector(
+    %src : memref<999x1x1xf32>) {
   // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
-  // CHECK-DAG:   %[[C98:.*]] = arith.constant 98 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
-  %c98 = arith.constant 98 : index
+  %c1 = arith.constant 1 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 99], strides: [99, 1]
-      : memref<1x1x1x100xf32> to memref<1x99xf32>
-  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C98]]] : memref<1x1x1x100xf32>
-  %0 = memref.load %reinterpret_cast[%c0, %c98] : memref<1x99xf32>
+    to offset: [0], sizes: [999, 1], strides: [1, 999]
+      : memref<999x1x1xf32> to memref<999x1xf32, strided<[1, 999]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]]] : memref<999x1x1xf32>
+  %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<999x1xf32,
+    strided<[1, 999]>>
   return
 }
 
@@ -418,7 +433,7 @@ func.func private @negative_nonzero_offset(
     %src : memref<1xi64>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
       to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
@@ -429,12 +444,11 @@ func.func private @negative_nonzero_offset(
 }
 
 // CHECK-LABEL: func.func private @negative_dynamic_shape(
-// CHECK-SAME:   %[[DIM:[A-Za-z][A-Za-z0-9-]*]]: index
 // CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<?xi64>
 func.func private @negative_dynamic_shape(%dim : index, %i : index,
     %src : memref<?xi64>) {
   %c0 = arith.constant 0 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, %dim], strides: [1, 1]
       : memref<?xi64> to memref<1x?xi64>
@@ -444,14 +458,12 @@ func.func private @negative_dynamic_shape(%dim : index, %i : index,
 }
 
 // CHECK-LABEL: func.func private @negative_dynamic_stride(
-// CHECK-SAME:   %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
-// CHECK-SAME:   %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
 // CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xi64>
 func.func private @negative_dynamic_stride(%stride0: index,
     %stride1: index, %src : memref<1x108xi64>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
     : memref<1x108xi64>
@@ -468,7 +480,7 @@ func.func private @negative_multiple_non_unit_dims(
   %src : memref<2x1x1x100xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [2, 100], strides: [100, 1]
       : memref<2x1x1x100xf32> to memref<2x100xf32>
@@ -483,7 +495,7 @@ func.func private @negative_inner_non_unit_dims(
     %src : memref<1x1x1x100xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
       : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
@@ -493,78 +505,17 @@ func.func private @negative_inner_non_unit_dims(
   return
 }
 
-// CHECK-LABEL: func.func private @negative_expand_out_of_bounds(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
-func.func private @negative_expand_out_of_bounds(%src : memref<1xi64>) {
-  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
-  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
-  %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
-      to memref<1x1x1xi64>
-  // CHECK:       memref.load %[[RC]]
-  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
-  return
-}
-
-// CHECK-LABEL: func.func private @negative_collapse_out_of_bounds(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1xi64>) {
-func.func private @negative_collapse_out_of_bounds(%src : memref<1x1x1xi64>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
-  %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
-      to memref<1x1xi64>
-  // CHECK:       memref.load %[[RC]]
-  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
-  return
-}
-
-// CHECK-LABEL: func.func private @negative_expand_negative_index(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
-func.func private @negative_expand_negative_index(%src : memref<1xi64>) {
-  %c0 = arith.constant 0 : index
-  %cneg1 = arith.constant -1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
-  %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
-      to memref<1x1x1xi64>
-  // CHECK:       memref.load %[[RC]]
-  %0 = memref.load %reinterpret_cast[%c0, %c0, %cneg1] : memref<1x1x1xi64>
-  return
-}
-
-// CHECK-LABEL: func.func private @negative_expand_left_discarded_indices(
-// CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
-func.func private @negative_expand_left_discarded_indices(
-    %src : memref<1x108xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
-  %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
-      : memref<1x108xf32> to memref<1x1x1x108xf32>
-  // CHECK:       memref.load %[[RC]]
-  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
-    : memref<1x1x1x108xf32>
-  return
-}
-
-// CHECK-LABEL: func.func private @negative_expand_right_discarded_indices(
-// CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
-func.func private @negative_expand_right_discarded_indices(
-    %src : memref<108x1xf32>) {
+// CHECK-LABEL: func.func private @negative_diff_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_diff_non_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
+  %c98 = arith.constant 98 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
-      : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+    to offset: [0], sizes: [1, 99], strides: [99, 1]
+      : memref<1x1x1x100xf32> to memref<1x99xf32>
   // CHECK:       memref.load %[[RC]]
-  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
-    : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  %0 = memref.load %reinterpret_cast[%c0, %c98] : memref<1x99xf32>
   return
 }

>From 9a9df3e6bef1c9af12d36710506adb265b0571b3 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Thu, 23 Apr 2026 12:26:18 +0200
Subject: [PATCH 4/4] Address third round of comments

---
 .../Transforms/ElideReinterpretCast.cpp       | 113 +++++++++++-------
 .../MemRef/elide-reinterpret-cast.mlir        |  41 ++++---
 2 files changed, 88 insertions(+), 66 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
index 06d2f48b0e551..68ae39b7e834d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -198,19 +198,18 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
-/// Describes the unique non-unit dimension of a MemRef shape.
+/// Captures info about MemRefs that are effectively 1D (the leading or trailing
+/// dims are all 1). The only accepted non-unit dim is either the leading or the
+/// trailing dim.
 ///
-/// This helper is only used for shapes that have at most one non-unit
-/// dimension. `exists` is false for all-ones shapes. Otherwise, `isOnLeft`
-/// indicates whether the non-unit dimension is on the left boundary.
+/// Examples:
+/// memref<1x1x4xf32>, memref<4x1x1xf32>, memref<1x1x1xf32>
 ///
-/// If `exists` is true and `isOnLeft` is false, the non-unit dimension is on
-/// the right boundary. Rank-1 non-unit MemRefs are treated as matching both
-/// boundaries and callers that care about the right boundary must account for
-/// that from the MemRef type.
-struct SingleNonUnitDimInfo {
-  bool exists = false;
-  bool isOnLeft = false;
+struct ShapeInfoFor1DMemRef {
+  // Are all dims == 1? `false` means that there is exactly one dim != 1.
+  bool allOnes = true;
+  // If there is a non-unit boundary dim, is it the leading or the trailing dim?
+  bool isLeadingDimNonUnit = false;
 };
 
 /// Returns information about a MemRef if it contains at most one non-unit
@@ -218,25 +217,24 @@ struct SingleNonUnitDimInfo {
 ///
 /// The single non-unit dimension, if present, must be on the left or right
 /// boundary. Rank-1 non-unit MemRefs are treated as being on both boundaries.
-static std::optional<SingleNonUnitDimInfo>
-getSingleNonUnitDimInfo(MemRefType type) {
+static std::optional<ShapeInfoFor1DMemRef>
+getShapeInfoFor1DMemRef(MemRefType type) {
   ArrayRef<int64_t> shape = type.getShape();
   int64_t nonUnitCount =
       llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
   // Return default values if missing nonUnitDim
   if (nonUnitCount == 0)
-    return SingleNonUnitDimInfo{};
+    return ShapeInfoFor1DMemRef{};
   // Return no info if MemRef breaks nonUnitDim requirements (more nonUnitDims)
   if (nonUnitCount > 1)
     return std::nullopt;
-
-  bool isOnLeft = shape.front() != 1;
   // Return no info if MemRef breaks nonUnitDim requirements (nonUnitDim in
   // non-boundary pos)
-  if (!isOnLeft && shape.back() == 1)
+  if (shape.front() == 1 && shape.back() == 1)
     return std::nullopt;
 
-  return SingleNonUnitDimInfo{/*exists=*/true, isOnLeft};
+  return ShapeInfoFor1DMemRef{/*allOnes=*/false,
+                              /*isLeadingDimNonUnit=*/shape.front() != 1};
 }
 
 static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
@@ -250,11 +248,13 @@ static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
 static std::optional<int64_t> getConstantIndex(Value v) {
   if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
     return cst.value();
+  // Non-constant (dynamic) indices have no statically known value.
   return std::nullopt;
 }
 
 static bool isConstantIndexExplicitlyOutOfBounds(Value idx,
                                                  int64_t upperBound) {
+  // Only statically known `arith.constant` indices are checked here.
   std::optional<int64_t> idxVal = getConstantIndex(idx);
   return idxVal && (*idxVal < 0 || *idxVal >= upperBound);
 }
@@ -270,46 +270,49 @@ static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
   auto inputTy = cast<MemRefType>(rc.getSource().getType());
   auto outputTy = cast<MemRefType>(rc.getResult().getType());
 
-  // This rewrite assumes "index re-use" and misses "index
-  // re-write/adjustment" logic, hence the requirement for the offset to be 0.
-  // Thus, storage shift and statically unknown offsets are rejected.
+  // Only zero, statically known offsets are accepted. Non-zero or dynamic
+  // offsets would require reasoning about storage shifts in the underlying
+  // reinterpret_cast, which this helper does not model.
   if (!hasStaticZeroOffset(rc))
     return false;
 
-  // The check assumes the rewrite relies on completely static shape info.
+  // Dynamic sizes/strides prevent precise reasoning about the underlying
+  // reinterpret_cast, so only fully static shape metadata is accepted.
   if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
       llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
     return false;
 
-  // The check assumes the rewrite supports shapes with at most one non-unit
-  // dimension. This excludes underlying multi-dimensional layouts and keeps the
-  // rewrite limited to unit-dim insertion/removal `reinterpret_cast`s.
-  std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
-      getSingleNonUnitDimInfo(inputTy);
-  std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
-      getSingleNonUnitDimInfo(outputTy);
-  // Bail out early if nonUnitDims don't follow rewrite assumptions.
+  // Only shapes with at most one non-unit dimension are accepted. This rules
+  // out more general multi-dimensional reinterpret_casts and restricts the
+  // helper to unit-dim insertion/removal around a single logical dimension.
+  std::optional<ShapeInfoFor1DMemRef> inputNonUnitDim =
+      getShapeInfoFor1DMemRef(inputTy);
+  std::optional<ShapeInfoFor1DMemRef> outputNonUnitDim =
+      getShapeInfoFor1DMemRef(outputTy);
+  // Bail out if either type does not satisfy the single-boundary-non-unit-dim
+  // restriction described above.
   if (!inputNonUnitDim || !outputNonUnitDim)
     return false;
 
   // The source and result must either both have a single non-unit dimension
   // or both be all-ones.
-  if (inputNonUnitDim->exists != outputNonUnitDim->exists)
+  if (inputNonUnitDim->allOnes != outputNonUnitDim->allOnes)
     return false;
-  if (!inputNonUnitDim->exists)
+  if (inputNonUnitDim->allOnes)
     return true;
 
   // The preserved non-unit dimension must have the same size.
-  if (inputTy.getDimSize(inputNonUnitDim->isOnLeft ? 0
-                                                   : inputTy.getRank() - 1) !=
-      outputTy.getDimSize(outputNonUnitDim->isOnLeft ? 0
-                                                     : outputTy.getRank() - 1))
+  if (inputTy.getDimSize(
+          inputNonUnitDim->isLeadingDimNonUnit ? 0 : inputTy.getRank() - 1) !=
+      outputTy.getDimSize(
+          outputNonUnitDim->isLeadingDimNonUnit ? 0 : outputTy.getRank() - 1))
     return false;
 
   // If both sides have rank > 1, the non-unit dimension must be on the same
   // boundary. Rank-1 MemRefs are accepted against either boundary.
   if (inputTy.getRank() != 1 && outputTy.getRank() != 1 &&
-      inputNonUnitDim->isOnLeft != outputNonUnitDim->isOnLeft)
+      inputNonUnitDim->isLeadingDimNonUnit !=
+          outputNonUnitDim->isLeadingDimNonUnit)
     return false;
 
   return true;
@@ -324,12 +327,35 @@ static bool areIndicesInBounds(memref::LoadOp load) {
 
   for (auto [pos, idx] : llvm::enumerate(load.getIndices())) {
     // FIXME: This should be ensured by the memref.load semantics.
+    // This rejects only explicit constant OOB indices. Dynamic/non-constant
+    // indices are not filtered here.
     if (isConstantIndexExplicitlyOutOfBounds(idx, rcOutputTy.getDimSize(pos)))
       return false;
   }
   return true;
 }
 
+/// Rewrites `memref.load` through a pure rank-only `reinterpret_cast` by
+/// mapping the load indices directly onto the source MemRef.
+///
+/// The accepted shapes are restricted by isPureRankExpansionOrCollapsingRC().
+///
+/// BEFORE (rank expansion)
+///   %view = memref.reinterpret_cast %src
+///     : memref<Nxf32> to memref<1x1xNxf32>
+///   %v = memref.load %view[%c0, %c0, %i] : memref<1x1xNxf32>
+///
+/// AFTER
+///   %v = memref.load %src[%i] : memref<Nxf32>
+///
+/// BEFORE (rank collapsing)
+///   %view = memref.reinterpret_cast %src
+///     : memref<1x1xNxf32> to memref<Nxf32>
+///   %v = memref.load %view[%i] : memref<Nxf32>
+///
+/// AFTER
+///   %c0 = arith.constant 0 : index
+///   %v = memref.load %src[%c0, %c0, %i] : memref<1x1xNxf32>
 struct RewriteLoadFromReinterpretCast
     : public OpRewritePattern<memref::LoadOp> {
 public:
@@ -369,14 +395,14 @@ struct RewriteLoadFromReinterpretCast
     //   collapsing : rcInput has the higher rank
     //
     // Example:
-    //   memref<999>     -> memref<1x1x999>   : extra dims to the left
-    //   memref<999x1x1> -> memref<999>       : extra dims to the right
+    //   memref<999>     -> memref<1x1x999>   : leading extra dims
+    //   memref<999x1x1> -> memref<999>       : trailing extra dims
     MemRefType expandedTy =
         rcOutputRank >= rcInputRank ? rcOutputTy : rcInputTy;
-    std::optional<SingleNonUnitDimInfo> expandedNonUnitDim =
-        getSingleNonUnitDimInfo(expandedTy);
+    std::optional<ShapeInfoFor1DMemRef> expandedNonUnitDim =
+        getShapeInfoFor1DMemRef(expandedTy);
     assert(expandedNonUnitDim && "expected a single boundary non-unit dim");
-    bool keepLeadingIndices = expandedNonUnitDim->isOnLeft;
+    bool keepLeadingIndices = expandedNonUnitDim->isLeadingDimNonUnit;
 
     if (rcOutputRank >= rcInputRank) {
       // Rank expansion:
@@ -419,9 +445,6 @@ struct RewriteLoadFromReinterpretCast
     if (rc.getResult().hasOneUse())
       rewriter.eraseOp(rc);
     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, rcInput, rcInputIdxs);
-
-    // Do not erase the reinterpret_cast here. After the load is rewritten it
-    // may become dead, and canonical DCE can remove it.
     return success();
   }
 };
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
index 9b563abadb3aa..67997bf78da0b 100644
--- a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -238,8 +238,8 @@ func.func private @expand_scalar(%src : memref<1xi64>) {
   %c0 = arith.constant 0 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
-      to memref<1x1x1xi64>
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1]
+    : memref<1xi64> to memref<1x1x1xi64>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xi64>
   %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xi64>
   return
@@ -253,8 +253,8 @@ func.func private @collapse_scalar(%src : memref<1x1x1xi64>) {
   %c0 = arith.constant 0 : index
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
-      to memref<1x1xi64>
+    to offset: [0], sizes: [1, 1], strides: [1, 1]
+    : memref<1x1x1xi64> to memref<1x1xi64>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C0]]] : memref<1x1x1xi64>
   %0 = memref.load %reinterpret_cast[%c0, %c0] : memref<1x1xi64>
   return
@@ -268,7 +268,7 @@ func.func private @expand_left_vector(%src : memref<999xi64>) {
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
-      : memref<999xi64> to memref<1x1x999xi64>
+    : memref<999xi64> to memref<1x1x999xi64>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
   %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
   return
@@ -284,7 +284,7 @@ func.func private @expand_left_vector_dynamic_index(%i : index,
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
-      : memref<999xi64> to memref<1x1x999xi64>
+    : memref<999xi64> to memref<1x1x999xi64>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[I]]] : memref<999xi64>
   %0 = memref.load %reinterpret_cast[%c0, %c0, %i] : memref<1x1x999xi64>
   return
@@ -299,7 +299,7 @@ func.func private @collapse_left_vector(%src : memref<1x1x999xi64>) {
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [999], strides: [1]
-      : memref<1x1x999xi64> to memref<999xi64>
+    : memref<1x1x999xi64> to memref<999xi64>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
   %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
   return
@@ -316,7 +316,7 @@ func.func private @partial_expand_left_vector(
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
-      : memref<1x999xf32> to memref<1x1x999xf32>
+    : memref<1x999xf32> to memref<1x1x999xf32>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<1x999xf32>
   %0 = memref.load %reinterpret_cast[%c0, %c0, %c1]
     : memref<1x1x999xf32>
@@ -335,7 +335,7 @@ func.func private @partial_collapse_left_vector(
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 999], strides: [999, 1]
-      : memref<1x1x999xf32> to memref<1x999xf32>
+    : memref<1x1x999xf32> to memref<1x999xf32>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x999xf32>
   %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x999xf32>
   return
@@ -349,7 +349,7 @@ func.func private @expand_right_vector(%src : memref<999xi64>) {
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
-      : memref<999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
+    : memref<999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
   %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<999x1x1xi64,
     strided<[1, 999, 999]>>
@@ -380,7 +380,7 @@ func.func private @collapse_right_vector_dynamic_index(%i : index,
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [999], strides: [1]
-      : memref<999x1x1xi64> to memref<999xi64>
+    : memref<999x1x1xi64> to memref<999xi64>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[I]], %[[C0]], %[[C0]]] : memref<999x1x1xi64>
   %0 = memref.load %reinterpret_cast[%i] : memref<999xi64>
   return
@@ -397,7 +397,7 @@ func.func private @partial_expand_right_vector(
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
-      : memref<999x1xf32> to memref<999x1x1xf32, strided<[1, 999, 999]>>
+    : memref<999x1xf32> to memref<999x1x1xf32, strided<[1, 999, 999]>>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<999x1xf32>
   %0 = memref.load %reinterpret_cast[%c1, %c0, %c0]
     : memref<999x1x1xf32, strided<[1, 999, 999]>>
@@ -416,7 +416,7 @@ func.func private @partial_collapse_right_vector(
   // CHECK-NOT:   memref.reinterpret_cast
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [999, 1], strides: [1, 999]
-      : memref<999x1x1xf32> to memref<999x1xf32, strided<[1, 999]>>
+    : memref<999x1x1xf32> to memref<999x1xf32, strided<[1, 999]>>
   // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]]] : memref<999x1x1xf32>
   %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<999x1xf32,
     strided<[1, 999]>>
@@ -435,8 +435,8 @@ func.func private @negative_nonzero_offset(
   %c1 = arith.constant 1 : index
   // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
-    to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
-      to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+    to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1]
+    : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
   // CHECK:       memref.load %[[RC]]
   %0 = memref.load %reinterpret_cast[%c0, %c0, %c1]
     : memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
@@ -451,7 +451,7 @@ func.func private @negative_dynamic_shape(%dim : index, %i : index,
   // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, %dim], strides: [1, 1]
-      : memref<?xi64> to memref<1x?xi64>
+    : memref<?xi64> to memref<1x?xi64>
   // CHECK:       memref.load %[[RC]]
   %0 = memref.load %reinterpret_cast[%c0, %i] : memref<1x?xi64>
   return
@@ -466,8 +466,7 @@ func.func private @negative_dynamic_stride(%stride0: index,
   // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
-    : memref<1x108xi64>
-      to memref<1x1xi64, strided<[?, ?]>>
+    : memref<1x108xi64> to memref<1x1xi64, strided<[?, ?]>>
   // CHECK:       memref.load %[[RC]]
   %0 = memref.load %reinterpret_cast[%c0, %c1]
     : memref<1x1xi64, strided<[?, ?]>>
@@ -483,7 +482,7 @@ func.func private @negative_multiple_non_unit_dims(
   // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [2, 100], strides: [100, 1]
-      : memref<2x1x1x100xf32> to memref<2x100xf32>
+    : memref<2x1x1x100xf32> to memref<2x100xf32>
   // CHECK:       memref.load %[[RC]]
   %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<2x100xf32>
   return
@@ -498,7 +497,7 @@ func.func private @negative_inner_non_unit_dims(
   // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
-      : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+    : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
   // CHECK:       memref.load %[[RC]]
   %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
     strided<[100, 1, 100]>>
@@ -514,7 +513,7 @@ func.func private @negative_diff_non_unit_dims(
   // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
   %reinterpret_cast = memref.reinterpret_cast %src
     to offset: [0], sizes: [1, 99], strides: [99, 1]
-      : memref<1x1x1x100xf32> to memref<1x99xf32>
+    : memref<1x1x1x100xf32> to memref<1x99xf32>
   // CHECK:       memref.load %[[RC]]
   %0 = memref.load %reinterpret_cast[%c0, %c98] : memref<1x99xf32>
   return



More information about the Mlir-commits mailing list