[Mlir-commits] [mlir] [MLIR] normalize-memrefs: skip offset/non-contiguous layouts (PR #174787)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 7 08:15:11 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Muhammad Abdul  (0xzre)

<details>
<summary>Changes</summary>

Resolves #<!-- -->174540 

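Fixes a bug where `normalize-memrefs` padded strided slices beyond their real extent. For example, the `memref<1x5xf32, strided<[17, 1], offset: 27>>` view in `@reinterpret_cast_non_zero_offset` was previously normalized to `memref<32xf32>`.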
- `normalize-memrefs`: skip strided layouts that have a non-zero (or dynamic) offset or non-contiguous strides, so that normalization does not change the logical domain. This keeps memrefs correct in the offset and linearized cases.
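- Tests: added a regression test for gap slices (`@strided_gap_skip`). Also updated the expectations of `@reinterpret_cast_non_zero_offset`, `@alloca`, `@shift`, and `@strided_cumulative`, whose offset or non-contiguous layouts are now left unnormalized.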

---
Full diff: https://github.com/llvm/llvm-project/pull/174787.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+13) 
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir (+12-4) 
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+7-7) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index deba1600e28a0..ed8d049f00d43 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1868,6 +1868,19 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
     // a trivial (identity) map.
     return memrefType;
   }
+
+  // Only normalize strided layouts that map to a contiguous, zero-based
+  // region. Memrefs with non-zero (or dynamic) offsets or non-contiguous
+  // strides would require expanding the logical domain, which is not safe
+  // without introducing copies.
+  int64_t offset = ShapedType::kDynamic;
+  SmallVector<int64_t, 4> strides;
+  if (succeeded(memrefType.getStridesAndOffset(strides, offset))) {
+    if (ShapedType::isDynamic(offset) || offset != 0)
+      return memrefType;
+    if (!memrefType.areTrailingDimsContiguous(rank))
+      return memrefType;
+  }
   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
   unsigned numSymbolicOperands = layoutMap.getNumSymbols();
 
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 344da4e5e2462..0006cb59ed0fe 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -191,8 +191,16 @@ func.func @reinterpret_cast_non_zero_offset(%arg0: index, %arg1: memref<1x10x17x
   %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xf32>
   cf.br ^bb3
 ^bb3:  // pred: ^bb1
-  // CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [32], strides: [1] : memref<2x17xf32> to memref<32xf32>
-  // CHECK: return %[[REINTERPRET_CAST]], %[[REINTERPRET_CAST]], %{{.*}}, %{{.*}}, %{{.*}} : memref<32xf32>, memref<32xf32>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
-  %reinterpret_cast = memref.reinterpret_cast %alloc_0 to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
-  return %reinterpret_cast, %reinterpret_cast, %alloc_0, %alloc, %alloc_1 : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+  // CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
+  // CHECK: return %[[REINTERPRET_CAST]], %[[REINTERPRET_CAST]], %{{.*}}, %{{.*}}, %{{.*}} : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+    %reinterpret_cast = memref.reinterpret_cast %alloc_0 to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
+    return %reinterpret_cast, %reinterpret_cast, %alloc_0, %alloc, %alloc_1 : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
 }
+
+  // CHECK-LABEL: strided_gap_skip
+  func.func @strided_gap_skip(%arg0: memref<11x13xf32>) -> memref<2x1xf32, strided<[13, 1], offset: 15>> {
+    // CHECK: %[[VIEW:.*]] = memref.reinterpret_cast %{{.*}} to offset: [15], sizes: [2, 1], strides: [13, 1] : memref<11x13xf32> to memref<2x1xf32, strided<[13, 1], offset: 15>>
+    // CHECK: return %[[VIEW]] : memref<2x1xf32, strided<[13, 1], offset: 15>>
+    %view = memref.reinterpret_cast %arg0 to offset: [15], sizes: [2, 1], strides: [13, 1] : memref<11x13xf32> to memref<2x1xf32, strided<[13, 1], offset: 15>>
+    return %view : memref<2x1xf32, strided<[13, 1], offset: 15>>
+  }
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index d2924fb1ecf77..b7e303cd5f664 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -33,28 +33,28 @@ func.func @permute() {
 
 // CHECK-LABEL: func @alloca
 func.func @alloca(%idx : index) {
-  // CHECK-NEXT: memref.alloca() : memref<65xf32>
+  // CHECK-NEXT: memref.alloca() : memref<64xf32, #map>
   %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
-  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+  // CHECK-NEXT: affine.load %{{.*}}[%arg0] : memref<64xf32, #map>
   affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
   affine.for %i = 0 to 64 {
     %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
     "prevent.dce"(%1) : (f32) -> ()
-    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}}] : memref<64xf32, #map>
   }
   return
 }
 
 // CHECK-LABEL: func @shift
 func.func @shift(%idx : index) {
-  // CHECK-NEXT: memref.alloc() : memref<65xf32>
+  // CHECK-NEXT: memref.alloc() : memref<64xf32, #map>
   %A = memref.alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
-  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+  // CHECK-NEXT: affine.load %{{.*}}[%arg0] : memref<64xf32, #map>
   affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
   affine.for %i = 0 to 64 {
     %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
     "prevent.dce"(%1) : (f32) -> ()
-    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}}] : memref<64xf32, #map>
   }
   return
 }
@@ -120,7 +120,7 @@ func.func @strided_cumulative() {
   affine.for %i = 0 to 2 {
     // CHECK: affine.for %[[IV1:.*]] =
     affine.for %j = 0 to 5 {
-      // CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
+      // CHECK: affine.load %{{.*}}[%[[IV0]], %[[IV1]]] : memref<2x5xf32, #map3>
       %1 = affine.load %A[%i, %j]  : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
       "prevent.dce"(%1) : (f32) -> ()
     }

``````````

</details>


https://github.com/llvm/llvm-project/pull/174787


More information about the Mlir-commits mailing list