[Mlir-commits] [mlir] [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (PR #188980)

Hocky Yudhiono llvmlistbot at llvm.org
Thu Apr 2 04:51:24 PDT 2026


================
@@ -230,19 +210,98 @@ static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
       .Default([&](auto op) { return success(); });
 }
 
+// Pattern for memref::AllocOp and memref::AllocaOp.
+//
+// The "source" memref for these ops IS the op's own result, so the generic
+// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
+// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
+// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
+// new ReinterpretCastOp, leaving the earlier ops with forward references
+// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
+//
+// Instead, sizes and strides are computed from the op's operands and type
+// (which all dominate the op), avoiding any reference to op.result until the
+// final replaceOpWithNewOp inside castAllocResult.
+template <typename T>
+struct AllocLikeFlattenPattern : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
+      return failure();
+
+    Location loc = op->getLoc();
+    auto memrefType = cast<MemRefType>(op.getType());
+    auto elemType = memrefType.getElementType();
+    if (!elemType.isIntOrFloat())
+      return failure();
+    unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
+
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+    int64_t staticOffset;
+    SmallVector<int64_t> staticStrides;
+    if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
+      return failure();
+    // Reject non-zero or dynamic base offsets (alloc results start at 0).
+    if (staticOffset != 0)
+      return failure();
+    SmallVector<OpFoldResult> strides;
+    strides.reserve(staticStrides.size());
+    for (int64_t stride : staticStrides) {
+      if (stride == ShapedType::kDynamic)
+        return failure();
+      strides.push_back(rewriter.getIndexAttr(stride));
+    }
+
+    // Compute the linearized flat size from sizes and strides (no SSA ops
+    // referencing op.result are created here).
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedOffset;
+    std::tie(linearizedInfo, linearizedOffset) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
+            sizes, strides);
+    (void)linearizedOffset;
+
+    // Build the flat 1-D MemRefType. The linearized size may be static or
+    // dynamic (OpFoldResult of either IntegerAttr or a Value).
+    int64_t flatDimSize = ShapedType::kDynamic;
+    if (auto attr = dyn_cast<Attribute>(linearizedInfo.linearizedSize))
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+        flatDimSize = intAttr.getInt();
+
+    auto flatMemrefType =
+        MemRefType::get({flatDimSize}, memrefType.getElementType(),
+                        StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
+                        memrefType.getMemorySpace());
+
+    // Collect the flat dynamic-size operand (empty for fully-static case).
+    SmallVector<Value, 1> dynSizes;
+    if (flatDimSize == ShapedType::kDynamic)
+      dynSizes.push_back(getValueFromOpFoldResult(
+          rewriter, loc, linearizedInfo.linearizedSize));
+
+    auto newOp = T::create(rewriter, loc, flatMemrefType, dynSizes,
----------------
hockyy wrote:

```suggestion
    auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
```

https://github.com/llvm/llvm-project/pull/188980


More information about the Mlir-commits mailing list