[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 23 05:39:42 PDT 2026


================
@@ -196,6 +198,234 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
+/// Summary of where the only non-unit dimension of a MemRef shape lives.
+///
+/// Only meaningful for shapes that have at most one non-unit dimension:
+///   * `exists == false`: every dimension is 1 (this includes rank-0 shapes).
+///   * `exists == true`, `isOnLeft == true`: the non-unit dimension is the
+///     first (left boundary) dimension.
+///   * `exists == true`, `isOnLeft == false`: the non-unit dimension is the
+///     last (right boundary) dimension.
+/// A rank-1 non-unit MemRef sits on both boundaries at once and is reported
+/// (by getSingleNonUnitDimInfo) with `isOnLeft == true`. Callers that care
+/// about the right boundary therefore have to inspect the MemRef rank too.
+struct SingleNonUnitDimInfo {
+  bool exists{false};
+  bool isOnLeft{false};
+};
+
+/// Locates the single non-unit dimension of `type`, if any.
+///
+/// Returns std::nullopt when `type` has more than one non-unit dimension, or
+/// when its only non-unit dimension is neither the first nor the last one.
+/// Otherwise returns a SingleNonUnitDimInfo whose `exists` is false for
+/// all-ones shapes (including rank 0) and whose `isOnLeft` tells whether the
+/// non-unit dimension is the leading one. A rank-1 non-unit MemRef is both the
+/// first and the last dimension; it is reported with `isOnLeft` set.
+///
+/// Note: any size other than 1, including dynamic sizes, counts as non-unit.
+static std::optional<SingleNonUnitDimInfo>
+getSingleNonUnitDimInfo(MemRefType type) {
+  ArrayRef<int64_t> dims = type.getShape();
+  auto isNonUnit = [](int64_t size) { return size != 1; };
+
+  switch (llvm::count_if(dims, isNonUnit)) {
+  case 0:
+    // Only unit dims (or rank 0): there is nothing to locate.
+    return SingleNonUnitDimInfo{};
+  case 1:
+    break;
+  default:
+    // Several non-unit dims are outside the supported shapes.
+    return std::nullopt;
+  }
+
+  // Exactly one non-unit dim: it has to sit on one of the two boundaries.
+  if (isNonUnit(dims.front()))
+    return SingleNonUnitDimInfo{/*exists=*/true, /*isOnLeft=*/true};
+  if (isNonUnit(dims.back()))
+    return SingleNonUnitDimInfo{/*exists=*/true, /*isOnLeft=*/false};
+  // The non-unit dim is in an interior position.
+  return std::nullopt;
+}
+
+/// Returns true if the offset of `rc` is statically known to be 0. A dynamic
+/// offset yields false.
+static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
+  ArrayRef<int64_t> staticOffsets = rc.getStaticOffsets();
+  // FIXME: Despite what `getStaticOffsets` implies, `reinterpret_cast` takes
+  // only a single offset. That should be fixed at the op definition level.
+  assert(staticOffsets.size() == 1 && "Expecting single offset");
+  int64_t offset = staticOffsets.front();
+  if (ShapedType::isDynamic(offset))
+    return false;
+  return offset == 0;
+}
+
+/// Returns the constant value of `v` when it is defined by an
+/// arith::ConstantIndexOp, and std::nullopt for any other producer (including
+/// block arguments).
+static std::optional<int64_t> getConstantIndex(Value v) {
+  auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
+  if (!cstOp)
+    return std::nullopt;
+  return cstOp.value();
+}
+
+/// Returns true if `idx` is a compile-time constant index that lies outside
+/// `[0, upperBound)`.
+///
+/// Non-constant (dynamic) indices cannot be analysed here and are never
+/// reported as out of bounds, i.e. they are treated as in-bounds. When
+/// `upperBound` is itself dynamic (`ShapedType::kDynamic`), only negative
+/// constants are reported, as the upper limit is unknown.
+static bool isConstantIndexExplicitlyOutOfBounds(Value idx,
+                                                 int64_t upperBound) {
+  std::optional<int64_t> idxVal = getConstantIndex(idx);
+  if (!idxVal)
+    return false;
+  if (*idxVal < 0)
+    return true;
+  // `ShapedType::kDynamic` is a sentinel, not a size: comparing against it
+  // numerically would flag every non-negative constant as out of bounds.
+  return !ShapedType::isDynamic(upperBound) && *idxVal >= upperBound;
+}
+
+/// Returns true if `rc` only inserts or removes unit dimensions around (at
+/// most) one non-unit dimension, with a static zero offset and fully static
+/// sizes and strides.
+///
+/// Examples accepted by this shape restriction:
+///   memref<999xf32>       <-> memref<1x1x999xf32>
+///   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
+///   memref<100x1xf32>     <-> memref<100x1x1xf32>
+///   memref<1xf32>         <-> memref<1x1x1xf32>
+///
+/// General reinterpret_casts are intentionally rejected.
+static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
+  auto srcTy = cast<MemRefType>(rc.getSource().getType());
+  auto dstTy = cast<MemRefType>(rc.getResult().getType());
+
+  // Indices are re-used as-is (there is no index re-write/adjustment logic),
+  // so storage shifts and statically unknown offsets cannot be supported.
+  if (!hasStaticZeroOffset(rc))
+    return false;
+
+  // Only completely static shape info is handled.
+  // NOTE(review): strides are only required to be static here, not unit;
+  // presumably contiguity is enforced elsewhere in the pattern - confirm.
+  bool hasDynamicSizeOrStride =
+      llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
+      llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic);
+  if (hasDynamicSizeOrStride)
+    return false;
+
+  // Each side may have at most one non-unit dim, located on a boundary. This
+  // keeps the check limited to unit-dim insertion/removal casts.
+  std::optional<SingleNonUnitDimInfo> srcInfo = getSingleNonUnitDimInfo(srcTy);
+  std::optional<SingleNonUnitDimInfo> dstInfo = getSingleNonUnitDimInfo(dstTy);
+  if (!srcInfo.has_value() || !dstInfo.has_value())
+    return false;
+
+  // Either both sides are all-ones, or both carry exactly one non-unit dim.
+  if (!srcInfo->exists || !dstInfo->exists)
+    return srcInfo->exists == dstInfo->exists;
+
+  // The preserved non-unit dim is the first dim when it is on the left and
+  // the last dim otherwise; its size must match on both sides.
+  auto nonUnitDimSize = [](MemRefType ty, const SingleNonUnitDimInfo &info) {
+    return ty.getDimSize(info.isOnLeft ? 0 : ty.getRank() - 1);
+  };
+  if (nonUnitDimSize(srcTy, *srcInfo) != nonUnitDimSize(dstTy, *dstInfo))
+    return false;
+
+  // A rank-1 side matches either boundary; otherwise the non-unit dim must
+  // sit on the same boundary on both sides.
+  bool eitherIsRank1 = srcTy.getRank() == 1 || dstTy.getRank() == 1;
+  return eitherIsRank1 || srcInfo->isOnLeft == dstInfo->isOnLeft;
+}
+
+/// Returns false if any constant index of `load` is provably outside the
+/// corresponding dimension of the reinterpret_cast result it reads from.
+/// Non-constant (dynamic) indices are never rejected.
+///
+/// Expects `load` to read directly from a memref.reinterpret_cast result that
+/// is a pure rank expansion/collapsing; the defining op is not null-checked.
+static bool areIndicesInBounds(memref::LoadOp load) {
+  auto castOp = load.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+  auto castResultTy = cast<MemRefType>(castOp.getResult().getType());
+
+  // FIXME: This should be ensured by the memref.load semantics.
+  return llvm::none_of(llvm::enumerate(load.getIndices()), [&](auto it) {
+    return isConstantIndexExplicitlyOutOfBounds(
+        it.value(), castResultTy.getDimSize(it.index()));
+  });
+}
----------------
banach-space wrote:

OK, this explanation helped. I got confused, but that's because I would implement that method differently. That's beside the point though. Still, a high-level comment for `isConstantIndexExplicitlyOutOfBounds` would help.

```cpp
// Return true if the input index is a constant that is out of bounds, i.e. it does not satisfy `0 <= idx < upperBound`. Fully dynamic index values (i.e. non-constant) that cannot be analysed are treated as in-bounds (the function returns false for them).
```

https://github.com/llvm/llvm-project/pull/188459


More information about the Mlir-commits mailing list