[Mlir-commits] [mlir] [MLIR] normalize-memrefs: skip offset/non-contiguous layouts (PR #174787)
Muhammad Abdul
llvmlistbot at llvm.org
Wed Jan 7 07:26:24 PST 2026
https://github.com/0xzre created https://github.com/llvm/llvm-project/pull/174787
Resolves #174540
Fixes a bug where normalize-memrefs would pad slices beyond their real extent.
- normalize-memrefs: skip layouts with offsets or non-contiguous strides so that the logical domain is not changed. This keeps memrefs correct for offset/linearized cases.
- Tests updated: added a regression test for gap slices.
>From aabc31922d7f6ccfc6f9e03eca8bc65f127d128e Mon Sep 17 00:00:00 2001
From: 0xzre <alilo.ghazali at gmail.com>
Date: Wed, 7 Jan 2026 21:41:34 +0700
Subject: [PATCH] [MLIR] normalize-memrefs: skip offset/non-contiguous layouts
---
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 13 +++++++++++++
.../Dialect/MemRef/normalize-memrefs-ops.mlir | 16 ++++++++++++----
mlir/test/Dialect/MemRef/normalize-memrefs.mlir | 14 +++++++-------
3 files changed, 32 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index deba1600e28a0..ed8d049f00d43 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1868,6 +1868,19 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
// a trivial (identity) map.
return memrefType;
}
+
+ // Only normalize strided layouts that map to a contiguous, zero-based
+ // region. Memrefs with non-zero (or dynamic) offsets or non-contiguous
+ // strides would require expanding the logical domain, which is not safe
+ // without introducing copies.
+ int64_t offset = ShapedType::kDynamic;
+ SmallVector<int64_t, 4> strides;
+ if (succeeded(memrefType.getStridesAndOffset(strides, offset))) {
+ if (ShapedType::isDynamic(offset) || offset != 0)
+ return memrefType;
+ if (!memrefType.areTrailingDimsContiguous(rank))
+ return memrefType;
+ }
AffineMap layoutMap = memrefType.getLayout().getAffineMap();
unsigned numSymbolicOperands = layoutMap.getNumSymbols();
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 344da4e5e2462..0006cb59ed0fe 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -191,8 +191,16 @@ func.func @reinterpret_cast_non_zero_offset(%arg0: index, %arg1: memref<1x10x17x
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xf32>
cf.br ^bb3
^bb3: // pred: ^bb1
- // CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [32], strides: [1] : memref<2x17xf32> to memref<32xf32>
- // CHECK: return %[[REINTERPRET_CAST]], %[[REINTERPRET_CAST]], %{{.*}}, %{{.*}}, %{{.*}} : memref<32xf32>, memref<32xf32>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
- %reinterpret_cast = memref.reinterpret_cast %alloc_0 to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
- return %reinterpret_cast, %reinterpret_cast, %alloc_0, %alloc, %alloc_1 : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+ // CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
+ // CHECK: return %[[REINTERPRET_CAST]], %[[REINTERPRET_CAST]], %{{.*}}, %{{.*}}, %{{.*}} : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+ %reinterpret_cast = memref.reinterpret_cast %alloc_0 to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
+ return %reinterpret_cast, %reinterpret_cast, %alloc_0, %alloc, %alloc_1 : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
}
+
+ // CHECK-LABEL: strided_gap_skip
+ func.func @strided_gap_skip(%arg0: memref<11x13xf32>) -> memref<2x1xf32, strided<[13, 1], offset: 15>> {
+ // CHECK: %[[VIEW:.*]] = memref.reinterpret_cast %{{.*}} to offset: [15], sizes: [2, 1], strides: [13, 1] : memref<11x13xf32> to memref<2x1xf32, strided<[13, 1], offset: 15>>
+ // CHECK: return %[[VIEW]] : memref<2x1xf32, strided<[13, 1], offset: 15>>
+ %view = memref.reinterpret_cast %arg0 to offset: [15], sizes: [2, 1], strides: [13, 1] : memref<11x13xf32> to memref<2x1xf32, strided<[13, 1], offset: 15>>
+ return %view : memref<2x1xf32, strided<[13, 1], offset: 15>>
+ }
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index d2924fb1ecf77..b7e303cd5f664 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -33,28 +33,28 @@ func.func @permute() {
// CHECK-LABEL: func @alloca
func.func @alloca(%idx : index) {
- // CHECK-NEXT: memref.alloca() : memref<65xf32>
+ // CHECK-NEXT: memref.alloca() : memref<64xf32, #map>
%A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
- // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[%arg0] : memref<64xf32, #map>
affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
affine.for %i = 0 to 64 {
%1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
"prevent.dce"(%1) : (f32) -> ()
- // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+ // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}}] : memref<64xf32, #map>
}
return
}
// CHECK-LABEL: func @shift
func.func @shift(%idx : index) {
- // CHECK-NEXT: memref.alloc() : memref<65xf32>
+ // CHECK-NEXT: memref.alloc() : memref<64xf32, #map>
%A = memref.alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
- // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[%arg0] : memref<64xf32, #map>
affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
affine.for %i = 0 to 64 {
%1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
"prevent.dce"(%1) : (f32) -> ()
- // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+ // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}}] : memref<64xf32, #map>
}
return
}
@@ -120,7 +120,7 @@ func.func @strided_cumulative() {
affine.for %i = 0 to 2 {
// CHECK: affine.for %[[IV1:.*]] =
affine.for %j = 0 to 5 {
- // CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
+ // CHECK: affine.load %{{.*}}[%[[IV0]], %[[IV1]]] : memref<2x5xf32, #map3>
%1 = affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
"prevent.dce"(%1) : (f32) -> ()
}
More information about the Mlir-commits
mailing list