[Mlir-commits] [mlir] [mlir] [IR] Allow zero strides in StridedLayoutAttr (PR #116463)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 15 21:48:30 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

Disallowing memrefs with a stride of 0 was intended to prevent internal aliasing, but it does not cover all cases: internal aliasing can still occur when the stride is less than the shape.

On the other hand, a stride of 0 can be very useful in certain scenarios. For example, in architectures that support multi-dimensional DMA, we can use memref::copy with a stride of 0 to achieve a broadcast effect.

This commit removes the restriction that strides in memrefs cannot be 0.

---
Full diff: https://github.com/llvm/llvm-project/pull/116463.diff


5 Files Affected:

- (modified) mlir/lib/IR/BuiltinAttributes.cpp (-4) 
- (modified) mlir/lib/IR/BuiltinTypes.cpp (-14) 
- (modified) mlir/test/Dialect/Affine/memref-stride-calculation.mlir (+2-2) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (-10) 
- (modified) mlir/test/IR/invalid-builtin-types.mlir (-5) 


``````````diff
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8861a940336133..f288dd42baaa16 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -245,9 +245,6 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
 LogicalResult
 StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           int64_t offset, ArrayRef<int64_t> strides) {
-  if (llvm::is_contained(strides, 0))
-    return emitError() << "strides must not be zero";
-
   return success();
 }
 
@@ -1815,7 +1812,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
   for (const auto &en : llvm::enumerate(strides)) {
     auto dim = en.index();
     auto stride = en.value();
-    assert(stride != 0 && "Invalid stride specification");
     auto d = getAffineDimExpr(dim, context);
     AffineExpr mult;
     // Static case.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 25e9f80c9963cb..c28c580690166f 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -803,20 +803,6 @@ static LogicalResult getStridesAndOffset(MemRefType t,
   for (auto &stride : strides)
     stride = simplifyAffineExpr(stride, numDims, numSymbols);
 
-  // In practice, a strided memref must be internally non-aliasing. Test
-  // against 0 as a proxy.
-  // TODO: static cases can have more advanced checks.
-  // TODO: dynamic cases would require a way to compare symbolic
-  // expressions and would probably need an affine set context propagated
-  // everywhere.
-  if (llvm::any_of(strides, [](AffineExpr e) {
-        return e == getAffineConstantExpr(0, e.getContext());
-      })) {
-    offset = AffineExpr();
-    strides.clear();
-    return failure();
-  }
-
   return success();
 }
 
diff --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
index cce1946b391e7e..29a5f5e0d5f440 100644
--- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
+++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
@@ -51,9 +51,9 @@ func.func @f(%0: index) {
   %26 = memref.alloc(%0)[] : memref<?xf32, affine_map<(i)[M]->(i)>>
 // CHECK: MemRefType offset: 0 strides: 1
   %27 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(M)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (s0)>> cannot be converted to strided form
+// CHECK: MemRefType offset: ? strides: 0
   %28 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(123)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (123)>> cannot be converted to strided form
+// CHECK: MemRefType offset: 123 strides: 0
   %29 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(M)>>
 // CHECK: MemRefType offset: ? strides:
   %30 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(123)>>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 51c4781c9022b2..f72ad48245f819 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,16 +245,6 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
 
 // -----
 
-func.func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
-  // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
-  %out = memref.reinterpret_cast %in to
-           offset: [0], sizes: [9, 10], strides: [42, 1]
-         : memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
-  return
-}
-
-// -----
-
 func.func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
   // expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 07854a25000feb..51612446d2e6a6 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -99,11 +99,6 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
 
 // -----
 
-// expected-error @below {{strides must not be zero}}
-func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
-
-// -----
-
 // expected-error @below {{expected the number of strides to match the rank}}
 func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/116463


More information about the Mlir-commits mailing list