[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 23 05:39:42 PDT 2026


================
@@ -196,6 +198,259 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
   }
 };
 
+/// Captures info about MemRefs that are effectively 1D (all dims are 1, except
+/// possibly one). The only accepted non-unit dim is either the leading or the
+/// trailing dim.
+///
+/// Examples:
+/// memref<1x1x4xf32>, memref<4x1x1xf32>, memref<1x1x1xf32>
+///
+struct ShapeInfoFor1DMemRef {
+  // Are all dims == 1? `false` means that there is exactly one dim != 1.
+  bool allOnes = true;
+  // If there is a non-unit boundary dim, is it the leading or the trailing dim?
+  bool isLeadingDimNonUnit = false;
+};
+
+/// Returns information about a MemRef if it contains at most one non-unit
+/// dimension.
+///
+/// The single non-unit dimension, if present, must be on the left or right
+/// boundary. Rank-1 non-unit MemRefs are treated as being on both boundaries.
+static std::optional<ShapeInfoFor1DMemRef>
+getShapeInfoFor1DMemRef(MemRefType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+  int64_t nonUnitCount =
+      llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
+  // Return default values if missing nonUnitDim
+  if (nonUnitCount == 0)
+    return ShapeInfoFor1DMemRef{};
+  // Return no info if MemRef breaks nonUnitDim requirements (more nonUnitDims)
----------------
banach-space wrote:

[nit] What are the `nonUnitDim requirements`? `nonUnitDim` is spelled like a variable, but there is no such variable ;-)

I know what you mean here, but only because I am familiar with the design. I think that folks reading this in the future might find it confusing ;-)
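For future readers, here is one way to state the rule without the pseudo-variable: at most one dim may differ from 1, and if one does, it must be the first or the last dim. The C++ sketch below restates that classification over a plain `std::vector<int64_t>` shape instead of an MLIR `MemRefType`. `ShapeInfo1D` and `classifyShape` are hypothetical names, not part of the patch. The quoted hunk stops before the boundary check, so the handling of rank-1 shapes (reported as "leading") is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical standalone mirror of ShapeInfoFor1DMemRef from the patch.
struct ShapeInfo1D {
  bool allOnes = true;              // Are all dims == 1?
  bool isLeadingDimNonUnit = false; // Is the single non-unit dim the leading one?
};

// Hypothetical helper: classifies a static shape (no MLIR types involved).
std::optional<ShapeInfo1D> classifyShape(const std::vector<int64_t> &shape) {
  auto isNonUnit = [](int64_t dim) { return dim != 1; };
  auto nonUnitCount = std::count_if(shape.begin(), shape.end(), isNonUnit);

  // All dims are 1 (this includes rank-0 shapes): keep the defaults.
  if (nonUnitCount == 0)
    return ShapeInfo1D{};

  // More than one non-unit dim: the shape is not effectively 1D.
  if (nonUnitCount > 1)
    return std::nullopt;

  // Exactly one non-unit dim: it must sit on the left or right boundary.
  // Assumption: for rank-1 shapes the leading check wins.
  if (isNonUnit(shape.front()))
    return ShapeInfo1D{/*allOnes=*/false, /*isLeadingDimNonUnit=*/true};
  if (isNonUnit(shape.back()))
    return ShapeInfo1D{/*allOnes=*/false, /*isLeadingDimNonUnit=*/false};

  // The single non-unit dim is in the middle, e.g. 1x4x1: rejected.
  return std::nullopt;
}

// Usage: the shapes from the doc comment, plus two rejected cases.
void checkExamples() {
  // memref<1x1x4xf32>: single non-unit trailing dim.
  auto trailing = classifyShape({1, 1, 4});
  assert(trailing && !trailing->allOnes && !trailing->isLeadingDimNonUnit);
  // memref<4x1x1xf32>: single non-unit leading dim.
  auto leading = classifyShape({4, 1, 1});
  assert(leading && !leading->allOnes && leading->isLeadingDimNonUnit);
  // memref<1x1x1xf32>: all dims are 1.
  assert(classifyShape({1, 1, 1})->allOnes);
  // Non-unit dim in the middle, or two non-unit dims: no info.
  assert(!classifyShape({1, 4, 1}));
  assert(!classifyShape({2, 4}));
}
```

Returning `std::nullopt` for the rejected shapes, rather than a flag inside the struct, keeps "not effectively 1D" distinct from "all ones", which is also how the quoted `std::optional<ShapeInfoFor1DMemRef>` signature reads.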

https://github.com/llvm/llvm-project/pull/188459

