[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 23 05:20:43 PDT 2026


================
@@ -196,6 +198,234 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
+/// Describes the unique non-unit dimension of a MemRef shape.
+///
+/// This helper is only used for shapes that have at most one non-unit
+/// dimension. `exists` is false for all-ones shapes. Otherwise, `isOnLeft`
+/// indicates whether the non-unit dimension is on the left boundary.
+///
+/// If `exists` is true and `isOnLeft` is false, the non-unit dimension is on
+/// the right boundary. Rank-1 non-unit MemRefs are treated as matching both
+/// boundaries and callers that care about the right boundary must account for
+/// that from the MemRef type.
+struct SingleNonUnitDimInfo {
+  bool exists = false;
+  bool isOnLeft = false;
+};
+
+/// Returns information about a MemRef if it contains at most one non-unit
+/// dimension.
+///
+/// The single non-unit dimension, if present, must be on the left or right
+/// boundary. Rank-1 non-unit MemRefs are treated as being on both boundaries.
+static std::optional<SingleNonUnitDimInfo>
+getSingleNonUnitDimInfo(MemRefType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+  int64_t nonUnitCount =
+      llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
+  // Return default values if missing nonUnitDim
+  if (nonUnitCount == 0)
+    return SingleNonUnitDimInfo{};
+  // Return no info if MemRef breaks nonUnitDim requirements (more nonUnitDims)
+  if (nonUnitCount > 1)
+    return std::nullopt;
+
+  bool isOnLeft = shape.front() != 1;
+  // Return no info if MemRef breaks nonUnitDim requirements (nonUnitDim in
+  // non-boundary pos)
+  if (!isOnLeft && shape.back() == 1)
+    return std::nullopt;
+
+  return SingleNonUnitDimInfo{/*exists=*/true, isOnLeft};
+}
+
+static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
+  ArrayRef<int64_t> offsets = rc.getStaticOffsets();
+  // FIXME: Despite what `getStaticOffsets` implies, `reinterpret_cast` takes
+  // only a single offset. That should be fixed at the op definition level.
+  assert(offsets.size() == 1 && "Expecting single offset");
+  return !ShapedType::isDynamic(offsets[0]) && offsets[0] == 0;
+}
+
+static std::optional<int64_t> getConstantIndex(Value v) {
+  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
+    return cst.value();
+  return std::nullopt;
+}
+
+static bool isConstantIndexExplicitlyOutOfBounds(Value idx,
+                                                 int64_t upperBound) {
+  std::optional<int64_t> idxVal = getConstantIndex(idx);
+  return idxVal && (*idxVal < 0 || *idxVal >= upperBound);
+}
+
+/// Examples accepted by this shape restriction:
+///   memref<999xf32>       <-> memref<1x1x999xf32>
+///   memref<1x108xf32>     <-> memref<1x1x1x108xf32>
+///   memref<100x1xf32>     <-> memref<100x1x1xf32>
+///   memref<1>             <-> memref<1x1x1>
+///
+/// General reinterpret_casts are intentionally rejected.
+static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
+  auto inputTy = cast<MemRefType>(rc.getSource().getType());
+  auto outputTy = cast<MemRefType>(rc.getResult().getType());
+
+  // This rewrite assumes "index re-use" and misses "index
+  // re-write/adjustment" logic, hence the requirement for the offset to be 0.
+  // Thus, storage shift and statically unknown offsets are rejected.
+  if (!hasStaticZeroOffset(rc))
+    return false;
+
+  // The check assumes the rewrite relies on completely static shape info.
+  if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
+      llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
+    return false;
+
+  // The check assumes the rewrite supports shapes with at most one non-unit
+  // dimension. This excludes underlying multi-dimensional layouts and keeps the
+  // rewrite limited to unit-dim insertion/removal `reinterpret_cast`s.
+  std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
+      getSingleNonUnitDimInfo(inputTy);
+  std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
+      getSingleNonUnitDimInfo(outputTy);
+  // Bail out early if nonUnitDims don't follow rewrite assumptions.
+  if (!inputNonUnitDim || !outputNonUnitDim)
+    return false;
+
+  // The source and result must either both have a single non-unit dimension
+  // or both be all-ones.
+  if (inputNonUnitDim->exists != outputNonUnitDim->exists)
+    return false;
+  if (!inputNonUnitDim->exists)
+    return true;
+
+  // The preserved non-unit dimension must have the same size.
+  if (inputTy.getDimSize(inputNonUnitDim->isOnLeft ? 0
+                                                   : inputTy.getRank() - 1) !=
+      outputTy.getDimSize(outputNonUnitDim->isOnLeft ? 0
+                                                     : outputTy.getRank() - 1))
+    return false;
+
+  // If both sides have rank > 1, the non-unit dimension must be on the same
+  // boundary. Rank-1 MemRefs are accepted against either boundary.
+  if (inputTy.getRank() != 1 && outputTy.getRank() != 1 &&
+      inputNonUnitDim->isOnLeft != outputNonUnitDim->isOnLeft)
+    return false;
----------------
banach-space wrote:

How about:
```mlir
 %reinterpret_cast = memref.reinterpret_cast %src
    : memref<1x1x1x100xf32> to memref<100x1x1xf32, strided<[100, 1, 100]>>
```
? This is interesting: in isolation, both the input and the output meet the criteria for your transformation. However, the Op itself does not.

https://github.com/llvm/llvm-project/pull/188459


More information about the Mlir-commits mailing list