[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 23 01:42:26 PDT 2026


================
@@ -195,6 +196,237 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
+/// Returns true when `v` is matched by the `m_Zero()` constant matcher.
+static bool isConstZero(Value v) {
+  const bool matchesZero = matchPattern(v, m_Zero());
+  return matchesZero;
+}
+
+/// Returns true if `rc` merely adds or removes unit dimensions around a single
+/// logical dimension, such that `op` (a load through the result of `rc`) can
+/// be expressed directly on the source of `rc`.
+///
+/// Illustrative example (the rewrite itself is performed by the caller):
+///
+///   BEFORE:
+///     %rc = memref.reinterpret_cast %src ...
+///         : memref<999xf32> to memref<1x1x999xf32>
+///     %v = memref.load %rc[%c0, %c0, %i] : memref<1x1x999xf32>
+///
+///   AFTER:
+///     %v = memref.load %src[%i] : memref<999xf32>
+///
+/// Conditions checked here:
+///   - the source is a ranked memref
+///   - offset, sizes and strides of `rc` are all static, and the offset is 0
+///   - source/result each have at most one non-unit dim, of identical size
+///   - if a non-unit dim exists, it sits at the same boundary (leftmost or
+///     rightmost) in both source and result; a rank-1 type counts as both
+///   - on rank expansion, every index of `op` that addresses a dimension
+///     dropped by the rewrite is a constant zero
+///
+/// Shape pairs accepted by these restrictions:
+///   memref<999xf32>       <-> memref<1x1x999xf32>
+///   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
+///   memref<100x1xf32>     <-> memref<100x1x1xf32>
+///
+/// General reinterpret_casts are intentionally rejected.
+///
+/// NOTE(review): stride *values* are not inspected here, only that they are
+/// static. Presumably the caller verifies contiguity - confirm.
+static bool isPureRankReshape(memref::ReinterpretCastOp rc, memref::LoadOp op) {
+  // memref.reinterpret_cast also accepts an unranked source; bail out instead
+  // of asserting inside cast<> in that case.
+  auto inputTy = dyn_cast<MemRefType>(rc.getSource().getType());
+  if (!inputTy)
+    return false;
+  auto outputTy = cast<MemRefType>(rc.getResult().getType());
+
+  auto offsets = rc.getStaticOffsets();
+  assert(offsets.size() == 1 && "Expecting single offset");
+
+  // The rewrite drops the reinterpret_cast and remaps indices directly to the
+  // source memref. That is only correct if there is no storage shift.
+  if (ShapedType::isDynamic(offsets[0]) || offsets[0] != 0)
+    return false;
+
+  // Require fully static metadata. The fold relies on knowing exactly which
+  // dimensions are unit dimensions and which indices may be ignored.
+  if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
+      llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
+    return false;
+
+  // Summary of the non-unit dimensions of a memref type: how many there are,
+  // plus the position and size of the last one seen.
+  struct NonUnitDims {
+    int count = 0;
+    unsigned pos = 0;
+    int64_t size = 1;
+  };
+  auto collectNonUnitDims = [](MemRefType ty) {
+    NonUnitDims dims;
+    for (unsigned i = 0, e = ty.getRank(); i < e; ++i) {
+      if (ty.getDimSize(i) == 1)
+        continue;
+      ++dims.count;
+      dims.pos = i;
+      dims.size = ty.getDimSize(i);
+    }
+    return dims;
+  };
+
+  const unsigned inputRank = inputTy.getRank();
+  const unsigned outputRank = outputTy.getRank();
+  const NonUnitDims input = collectNonUnitDims(inputTy);
+  const NonUnitDims output = collectNonUnitDims(outputTy);
+
+  // The source and result must have the same number of non-unit dimensions:
+  // either both are all-ones, or both have exactly one non-unit dimension.
+  // This excludes underlying multi-dimensional layouts and keeps the fold
+  // limited to unit-dim insertion/removal reshapes.
+  if (input.count > 1 || output.count > 1 || input.count != output.count)
+    return false;
+
+  // For the all-ones case, treat it like the "non-unit on the right" case.
+  bool nonUnitOnLeft = false;
+  if (input.count == 1) {
+    // Size of the non-unit dimension must be the same.
+    if (input.size != output.size)
+      return false;
+
+    // The non-unit dimension must live at the same boundary (first or last
+    // dimension) on both memrefs; the index remapping below is exclusive to
+    // these cases. Checking each side independently is not enough: e.g.
+    // memref<100x1xf32> -> memref<1x100xf32> would otherwise be accepted and
+    // the kept indices would address the wrong source dimension.
+    const bool inputLeft = input.pos == 0;
+    const bool inputRight = input.pos == inputRank - 1;
+    const bool outputLeft = output.pos == 0;
+    const bool outputRight = output.pos == outputRank - 1;
+    if (!((inputLeft && outputLeft) || (inputRight && outputRight)))
+      return false;
+
+    // Shapes like [N, 1, 1] keep the leftmost indices; shapes like [1, 1, N]
+    // keep the rightmost ones.
+    nonUnitOnLeft = outputLeft;
+  }
+
+  // Rank reduction: no further index restrictions are imposed here.
+  if (outputRank < inputRank)
+    return true;
+
+  // Rank expansion (or equal rank).
+  //
+  // The rewrite keeps only inputRank indices: the leftmost ones when the
+  // expansion happened on the right, the rightmost ones when it happened on
+  // the left. Dropping an index is only semantics-preserving if it is zero.
+  const unsigned firstKept = nonUnitOnLeft ? 0 : outputRank - inputRank;
+  const unsigned endKept = firstKept + inputRank;
+  for (auto [pos, idx] : llvm::enumerate(op.getIndices())) {
+    const bool isDropped = pos < firstKept || pos >= endKept;
+    if (isDropped && !isConstZero(idx))
+      return false;
+  }
+
+  return true;
+}
+
----------------
banach-space wrote:

By "high-level" comment I meant something like the following (with `BEFORE` + `AFTER`): https://github.com/llvm/llvm-project/blob/7aa2b040236bfa8b60ebec69af60f5a334ee160e/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp?plain=1#L113-L150

Note, the emphasis is on "high-level" rather than long/detailed ;-)

https://github.com/llvm/llvm-project/pull/188459


More information about the Mlir-commits mailing list