[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 30 22:34:39 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ming Yan (NexMing)
<details>
<summary>Changes</summary>
When a constant offset or size is invalid (i.e. negative), the pattern no longer bails out of the whole fold. Instead it leaves only that value unfolded and continues folding the remaining valid offsets, sizes, and strides.
---
Full diff: https://github.com/llvm/llvm-project/pull/189533.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+55-27)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+2-6)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f31811ad7b98e..8aef3d38aeb2d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2273,41 +2273,69 @@ struct ReinterpretCastOpConstantFolder
LogicalResult matchAndRewrite(ReinterpretCastOp op,
PatternRewriter &rewriter) const override {
- unsigned srcStaticCount = llvm::count_if(
- llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
- op.getMixedStrides()),
- [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+ MemRefType srcType = op.getType();
- SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ OpFoldResult offset = op.getConstifiedMixedOffset();
SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
- // TODO: Using counting comparison instead of direct comparison because
- // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
- // IntegerAttrs, while constifyIndexValues (and therefore
- // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
- if (srcStaticCount ==
- llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
- [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
- return failure();
+ int64_t layoutOffset = ShapedType::kDynamic;
- // Do not fold if the offset is a negative constant; ViewLikeInterface
- // verifies that static offsets are non-negative.
- if (auto cst = getConstantIntValue(offsets[0]))
+ if (auto cst = getConstantIntValue(offset)) {
+ // If the offset is a negative constant, we can't fold it because the
+ // resulting memref type would be invalid. In that case, we keep the
+ // original offset.
if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant offset is invalid");
+ offset = op.getMixedOffsets()[0];
+ else
+ layoutOffset = *cst;
+ }
- // Do not fold if any size is a negative constant; MemRefType::get asserts
- // non-negative static sizes.
- for (OpFoldResult sizeOfr : sizes)
- if (auto cst = getConstantIntValue(sizeOfr))
- if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant size is invalid");
+ int64_t lastStride = 1;
+ bool isContiguousMemrefType = (layoutOffset == 0);
+ SmallVector<int64_t> layoutStrides, shapes;
+
+ for (auto [stride, size, srcSize] :
+ llvm::zip(strides, sizes, op.getMixedSizes())) {
+ int64_t layoutStride = ShapedType::kDynamic;
+ if (auto cstStride = getConstantIntValue(stride)) {
+ layoutStride = *cstStride;
+ isContiguousMemrefType &= (layoutStride == lastStride);
+ }
+ layoutStrides.push_back(layoutStride);
+
+ int64_t layoutSize = ShapedType::kDynamic;
+ if (auto cstSize = getConstantIntValue(size)) {
+ // If the size is a negative constant, we can't fold it because the
+ // resulting memref type would be invalid. In that case, we keep the
+ // original size.
+ if (*cstSize < 0)
+ size = srcSize;
+ else
+ layoutSize = *cstSize;
+ }
+ shapes.push_back(layoutSize);
+
+ if (ShapedType::isStatic(lastStride) && ShapedType::isStatic(layoutSize))
+ lastStride = lastStride * layoutSize;
+ else
+ lastStride = ShapedType::kDynamic;
+ }
+
+ MemRefType dstType = MemRefType::get(
+ shapes, srcType.getElementType(),
+ isContiguousMemrefType
+ ? nullptr
+ : StridedLayoutAttr::get(srcType.getContext(), layoutOffset,
+ layoutStrides),
+ srcType.getMemorySpace());
+
+ if (dstType == srcType)
+ return failure();
- auto newReinterpretCast = ReinterpretCastOp::create(
- rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+ auto newReinterpretCast =
+ ReinterpretCastOp::create(rewriter, op->getLoc(), dstType,
+ op.getSource(), offset, sizes, strides);
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index fb1e7d00feb47..ca415ce4f0483 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1292,10 +1292,8 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
// which triggers an assertion in MemRefType::get (issue #188407).
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[SZ:.*]] = arith.constant -1 : index
-// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -1313,10 +1311,8 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
// ViewLikeInterface constraint that offsets must be non-negative.
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[NEG:.*]] = arith.constant -1 : index
-// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/189533
More information about the Mlir-commits
mailing list