[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

ioana ghiban llvmlistbot at llvm.org
Wed Mar 25 04:14:33 PDT 2026


https://github.com/ioghiban updated https://github.com/llvm/llvm-project/pull/188459

>From f285b1dc9d61d679c4d49c52ce8a9b7ac3d1fef5 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Tue, 24 Mar 2026 17:34:52 +0100
Subject: [PATCH] [memref] Simplify loads from reinterpret_cast of 1D
 contiguous memrefs

Assisted-by: ChatGPT (refine implementation + tests). I reviewed all code and tests before submission.
---
 .../Transforms/ElideReinterpretCast.cpp       | 241 +++++++++++++-
 .../MemRef/elide-reinterpret-cast.mlir        | 309 +++++++++++++++++-
 2 files changed, 548 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
index dc139d892f5e5..49d764fc5aee1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <cassert>
@@ -195,6 +196,227 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
+static bool isConstZero(Value v) { return matchPattern(v, m_Zero()); }
+
+/// Returns true if the load `op` through the result of `rc` may be rewritten
+/// by FoldReinterpretCastLoad into a load from the source of `rc`.
+///
+/// This fold only handles reinterpret_casts that behave like pure rank
+/// reshapes of a single logical dimension:
+///
+///   - the source is a ranked memref whose layout has a static zero offset
+///   - all metadata of `rc` is static and its offset is 0
+///   - source/result each have at most one non-unit dim
+///   - if a non-unit dim exists, it sits at the same boundary (first or last
+///     dim) on both sides, with the same size and stride on both sides
+///   - every load index discarded by the rewrite is a constant zero
+///
+/// Examples accepted by this shape restriction:
+///   memref<999xf32>       <-> memref<1x1x999xf32>
+///   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
+///   memref<100x1xf32>     <-> memref<100x1x1xf32>
+///
+/// General reinterpret_casts are intentionally rejected.
+static bool isPureRankReshape(memref::ReinterpretCastOp rc, memref::LoadOp op) {
+  // reinterpret_cast also accepts unranked sources, which have no layout to
+  // compare against.
+  auto inputTy = dyn_cast<MemRefType>(rc.getSource().getType());
+  if (!inputTy)
+    return false;
+  auto outputTy = cast<MemRefType>(rc.getResult().getType());
+
+  auto offsets = rc.getStaticOffsets();
+  assert(offsets.size() == 1 && "Expecting single offset");
+
+  // The rewrite drops the reinterpret_cast and remaps indices directly to the
+  // source memref. That is only correct if there is no storage shift.
+  if (ShapedType::isDynamic(offsets[0]) || offsets[0] != 0)
+    return false;
+
+  auto sizes = rc.getStaticSizes();
+  auto strides = rc.getStaticStrides();
+
+  // Require fully static metadata. The fold relies on knowing exactly which
+  // dimensions are unit dimensions and which indices may be ignored.
+  if (llvm::any_of(sizes, ShapedType::isDynamic) ||
+      llvm::any_of(strides, ShapedType::isDynamic))
+    return false;
+
+  // reinterpret_cast addresses the underlying buffer of its source and ignores
+  // the source layout, whereas the rewritten load goes through that layout.
+  // Both agree only if the source layout has a static zero offset and (checked
+  // below) the same stride for the non-unit dim.
+  SmallVector<int64_t> inputStrides;
+  int64_t inputOffset = 0;
+  if (failed(inputTy.getStridesAndOffset(inputStrides, inputOffset)) ||
+      inputOffset != 0)
+    return false;
+
+  // Counts the non-unit dims of `ty` and records the position of the last one
+  // in `pos` (only meaningful when there is exactly one). Dynamic dims count
+  // as non-unit.
+  auto countNonUnitDims = [](MemRefType ty, unsigned &pos) {
+    int count = 0;
+    for (unsigned i = 0, e = ty.getRank(); i < e; ++i) {
+      if (ty.getDimSize(i) != 1) {
+        ++count;
+        pos = i;
+      }
+    }
+    return count;
+  };
+
+  unsigned inputRank = inputTy.getRank();
+  unsigned outputRank = outputTy.getRank();
+  unsigned inputNonUnitPos = 0;
+  unsigned outputNonUnitPos = 0;
+  int inputNonUnitCount = countNonUnitDims(inputTy, inputNonUnitPos);
+  int outputNonUnitCount = countNonUnitDims(outputTy, outputNonUnitPos);
+
+  // The source and result must have the same number of non-unit dimensions:
+  // either both are all-ones, or both have exactly one non-unit dimension.
+  if (inputNonUnitCount > 1 || inputNonUnitCount != outputNonUnitCount)
+    return false;
+
+  // The higher-rank side is the "expanded" one. FoldReinterpretCastLoad aligns
+  // the index lists of both sides at the boundary where the expanded side
+  // keeps its non-unit dim (the last dim when all dims are unit).
+  bool isExpansion = outputRank >= inputRank;
+
+  if (inputNonUnitCount == 1) {
+    unsigned highPos = isExpansion ? outputNonUnitPos : inputNonUnitPos;
+    unsigned lowPos = isExpansion ? inputNonUnitPos : outputNonUnitPos;
+    unsigned highRank = isExpansion ? outputRank : inputRank;
+    unsigned lowRank = isExpansion ? inputRank : outputRank;
+
+    // The non-unit dim must sit at the same boundary on both sides, otherwise
+    // its index would be routed into a unit dim. For example,
+    // memref<1x999xf32> -> memref<999x1x1xf32> is rejected.
+    bool sameBoundary =
+        highPos == 0 ? lowPos == 0
+                     : highPos == highRank - 1 && lowPos == lowRank - 1;
+    if (!sameBoundary)
+      return false;
+
+    // The non-unit dim must cover the same elements on both sides: same size
+    // and same distance between consecutive elements.
+    if (inputTy.getDimSize(inputNonUnitPos) !=
+            outputTy.getDimSize(outputNonUnitPos) ||
+        strides[outputNonUnitPos] != inputStrides[inputNonUnitPos])
+      return false;
+  }
+
+  // During rank expansion the rewrite keeps only inputRank indices and drops
+  // the ones of the extra unit dims. That is only semantics-preserving if
+  // every dropped index is zero:
+  //   - non-unit dim first ([N, 1, 1]):    keep the leading inputRank indices
+  //   - otherwise ([1, 1, N], all-ones):   keep the trailing inputRank indices
+  if (isExpansion) {
+    bool nonUnitOnLeft = outputNonUnitCount == 1 && outputNonUnitPos == 0;
+    unsigned rankDiff = outputRank - inputRank;
+    for (auto [pos, idx] : llvm::enumerate(op.getIndices())) {
+      bool dropped = nonUnitOnLeft ? pos >= inputRank : pos < rankDiff;
+      if (dropped && !isConstZero(idx))
+        return false;
+    }
+  }
+
+  return true;
+}
+
+/// Rewrites a memref.load through a reinterpret_cast accepted by
+/// isPureRankReshape() into a load from the source of the reinterpret_cast:
+///
+///   %rc = memref.reinterpret_cast %src to offset: [0], sizes: [1, 1, 999],
+///         strides: [999, 999, 1] : memref<999xf32> to memref<1x1x999xf32>
+///   %v = memref.load %rc[%c0, %c0, %i] : memref<1x1x999xf32>
+///
+/// becomes
+///
+///   %v = memref.load %src[%i] : memref<999xf32>
+///
+/// The reinterpret_cast is erased as well when the load was its only user.
+struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::LoadOp op,
+                                PatternRewriter &rewriter) const override {
+    auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+    if (!rc)
+      return failure();
+
+    // This fold is only correct for the narrow "pure rank reshape of a single
+    // logical dimension" cases accepted by isPureRankReshape().
+    if (!isPureRankReshape(rc, op))
+      return failure();
+
+    auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
+    auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
+    int64_t rcOutputRank = rcOutputTy.getRank();
+    int64_t rcInputRank = rcInputTy.getRank();
+
+    SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
+    SmallVector<Value> rcInputIdxs;
+    rcInputIdxs.reserve(rcInputRank);
+
+    // The higher-rank side tells at which boundary the index lists of both
+    // sides are aligned:
+    //
+    //   expansion: rcOutput has the higher (or equal) rank
+    //   collapse : rcInput has the higher rank
+    //
+    // Example:
+    //   memref<999>     -> memref<1x1x999>   : extra dims to the left
+    //   memref<999x1x1> -> memref<999>       : extra dims to the right
+    //
+    // The higher-rank side is 0-d only if both sides are; there are no indices
+    // to align then, and dim 0 must not be queried.
+    bool isExpansion = rcOutputRank >= rcInputRank;
+    MemRefType expandedTy = isExpansion ? rcOutputTy : rcInputTy;
+    bool nonUnitOnLeft =
+        expandedTy.getRank() > 0 && expandedTy.getDimSize(0) != 1;
+
+    if (isExpansion) {
+      // Rank expansion:
+      //   memref<N>   -> memref<1x1xN>   : keep the last rcInputRank indices
+      //   memref<N>   -> memref<Nx1x1>   : keep the first rcInputRank indices
+      //
+      // Any discarded indices are known to be zero from isPureRankReshape().
+      int64_t first = nonUnitOnLeft ? 0 : rcOutputRank - rcInputRank;
+      rcInputIdxs.append(idxs.begin() + first,
+                         idxs.begin() + first + rcInputRank);
+    } else {
+      // Rank collapse:
+      //   memref<1x1xN> -> memref<N>      : reinsert leading zeros
+      //   memref<Nx1x1> -> memref<N>      : reinsert trailing zeros
+      //
+      // The collapsed-away dimensions are unit dims, so re-adding them with
+      // zero indices preserves semantics.
+      Value c0 = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+      int64_t rankDiff = rcInputRank - rcOutputRank;
+      if (nonUnitOnLeft) {
+        rcInputIdxs.append(idxs.begin(), idxs.end());
+        rcInputIdxs.append(rankDiff, c0);
+      } else {
+        rcInputIdxs.append(rankDiff, c0);
+        rcInputIdxs.append(idxs.begin(), idxs.end());
+      }
+    }
+    assert(static_cast<int64_t>(rcInputIdxs.size()) == rcInputRank &&
+           "rewritten load must index every source dimension");
+
+    // `op` is the only user of `rc` iff `rc` has a single use. Record that
+    // before the rewrite and erase `rc` only after `op` has been replaced:
+    // erasing an op that still has uses violates the rewriter contract.
+    bool rcBecomesDead = rc.getResult().hasOneUse();
+    Value rcInput = rc.getSource();
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, rcInput, rcInputIdxs);
+    if (rcBecomesDead)
+      rewriter.eraseOp(rc);
+    return success();
+  }
+};
+
 struct ElideReinterpretCastPass
     : public memref::impl::ElideReinterpretCastPassBase<
           ElideReinterpretCastPass> {
@@ -210,6 +442,12 @@ struct ElideReinterpretCastPass
         return true;
       return !isScalarSlice(rc);
     });
+    target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp op) {
+      auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+      if (!rc)
+        return true;
+      return !isPureRankReshape(rc, op);
+    });
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -221,5 +459,6 @@ struct ElideReinterpretCastPass
 
 void mlir::memref::populateElideReinterpretCastPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CopyToScalarLoadAndStore>(patterns.getContext());
+  MLIRContext *context = patterns.getContext();
+  patterns.add<CopyToScalarLoadAndStore, FoldReinterpretCastLoad>(context);
 }
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
index da47562e9c0d6..5733c97ea8f3b 100644
--- a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -memref-elide-reinterpret-cast %s \
+// RUN: | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // Positive tests
@@ -220,3 +221,357 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
   : memref<1x1xf32> to memref<1x1xf32>
   return
 }
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Positive tests: memref.load through memref.reinterpret_cast
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @reshape_expand_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @reshape_expand_scalar(%src : memref<1xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @reshape_collapse_scalar(%src : memref<1x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+      to memref<1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C0]]] : memref<1x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0] : memref<1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<999xi64> to memref<1x1x999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x999xi64>) {
+func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<1x1x999xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
+  %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @reshape_expand_left_inner_unit_dims(
+    %src : memref<1x108xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+      : memref<1x108xf32> to memref<1x1x1x108xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0, %c1]
+    : memref<1x1x1x108xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @reshape_collapse_left_inner_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 100], strides: [100, 1]
+      : memref<1x1x1x100xf32> to memref<1x100xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1x100xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x100xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
+      : memref<999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<999x1x1xi64,
+    strided<[1, 999, 999]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999x1x1xi64>) {
+func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<999x1x1xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0]]] : memref<999x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @reshape_expand_right_inner_unit_dims(
+    %src : memref<108x1xf32>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+      : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<108x1xf32>
+  %0 = memref.load %reinterpret_cast[%c1, %c0, %c0, %c0]
+    : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<100x1x1x1xf32>) {
+func.func private @reshape_collapse_right_inner_unit_dims(
+    %src : memref<100x1x1x1xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [100, 1], strides: [1, 100]
+      : memref<100x1x1x1xf32> to memref<100x1xf32, strided<[1, 100]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]], %[[C0_0]]] : memref<100x1x1x1xf32>
+  %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<100x1xf32,
+    strided<[1, 100]>>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @negative_reshape_nonzero_offset(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_reshape_nonzero_offset(
+    %src : memref<1xi64>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1]
+    : memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_dynamic_shape(
+// CHECK-SAME:   %[[DIM:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<?xi64>
+func.func private @negative_reshape_dynamic_shape(%dim : index, %i : index,
+    %src : memref<?xi64>) {
+  %c0 = arith.constant 0 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, %[[DIM]]], strides: [1, 1] : memref<?xi64> to memref<1x?xi64>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, %dim], strides: [1, 1]
+      : memref<?xi64> to memref<1x?xi64>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %i] : memref<1x?xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_dynamic_stride(
+// CHECK-SAME:   %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xi64>
+func.func private @negative_reshape_dynamic_stride(%stride0: index,
+    %stride1: index, %src : memref<1x108xi64>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1], strides: [%[[STR0]], %[[STR1]]] : memref<1x108xi64> to memref<1x1xi64, strided<[?, ?]>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+    : memref<1x108xi64>
+      to memref<1x1xi64, strided<[?, ?]>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1]
+    : memref<1x1xi64, strided<[?, ?]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_multiple_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<2x1x1x100xf32>) {
+func.func private @negative_reshape_multiple_non_unit_dims(
+    %src : memref<2x1x1x100xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [2, 100], strides: [100, 1] : memref<2x1x1x100xf32> to memref<2x100xf32>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [2, 100], strides: [100, 1]
+      : memref<2x1x1x100xf32> to memref<2x100xf32>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<2x100xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_diff_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_reshape_diff_non_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 101], strides: [101, 1] : memref<1x1x1x100xf32> to memref<1x101xf32>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 101], strides: [101, 1]
+      : memref<1x1x1x100xf32> to memref<1x101xf32>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x101xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_inner_non_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_reshape_inner_non_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100] : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
+      : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
+    strided<[100, 1, 100]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_expand_left_discarded_indices(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @negative_reshape_expand_left_discarded_indices(
+    %src : memref<1x108xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1] : memref<1x108xf32> to memref<1x1x1x108xf32>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+      : memref<1x108xf32> to memref<1x1x1x108xf32>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+    : memref<1x1x1x108xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_expand_right_discarded_indices(
+// CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @negative_reshape_expand_right_discarded_indices(
+    %src : memref<108x1xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108] : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+      : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+    : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  return
+}
+
+// The non-unit dim sits at opposite boundaries of source and result, so the
+// result index of the non-unit dim cannot be forwarded to the source.
+// CHECK-LABEL: func.func private @negative_reshape_opposite_boundaries(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x999xi64>) {
+func.func private @negative_reshape_opposite_boundaries(
+    %src : memref<1x999xi64>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
+      : memref<1x999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c1, %c0, %c0]
+    : memref<999x1x1xi64, strided<[1, 999, 999]>>
+  return
+}
+
+// The source layout has stride 2 while the reinterpret_cast walks its buffer
+// with stride 1, so the two memrefs address different elements.
+// CHECK-LABEL: func.func private @negative_reshape_strided_source(
+// CHECK-SAME:    %[[SRC:.*]]: memref<4xi64, strided<[2]>>) {
+func.func private @negative_reshape_strided_source(
+    %src : memref<4xi64, strided<[2]>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 4], strides: [4, 1]
+      : memref<4xi64, strided<[2]>> to memref<1x4xi64>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x4xi64>
+  return
+}
+
+// reinterpret_cast ignores the offset of the source layout, whereas a load
+// from the source honours it.
+// CHECK-LABEL: func.func private @negative_reshape_source_offset(
+// CHECK-SAME:    %[[SRC:.*]]: memref<4xi64, strided<[1], offset: 4>>) {
+func.func private @negative_reshape_source_offset(
+    %src : memref<4xi64, strided<[1], offset: 4>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]]
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 4], strides: [4, 1]
+      : memref<4xi64, strided<[1], offset: 4>> to memref<1x4xi64>
+  // CHECK:       memref.load %[[RC]]
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x4xi64>
+  return
+}



More information about the Mlir-commits mailing list