[Mlir-commits] [mlir] dbe159b - [mlir] [IR] Allow zero strides in StridedLayoutAttr (#116463)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 20 22:17:31 PST 2024


Author: donald chen
Date: 2024-11-21T14:17:28+08:00
New Revision: dbe159b3f74ea41e16782fe5708756507d4a014f

URL: https://github.com/llvm/llvm-project/commit/dbe159b3f74ea41e16782fe5708756507d4a014f
DIFF: https://github.com/llvm/llvm-project/commit/dbe159b3f74ea41e16782fe5708756507d4a014f.diff

LOG: [mlir] [IR] Allow zero strides in StridedLayoutAttr (#116463)

Disallowing memrefs with a stride of 0 was intended to prevent internal
aliasing, but it does not cover all cases: internal aliasing can still
occur whenever a stride is too small for the shape (for example, strides
[1, 1] on a 4x4 memref).

On the other hand, a stride of 0 can be very useful in certain
scenarios. For example, on architectures that support multi-dimensional
DMA, memref::copy with a stride of 0 can be used to achieve a broadcast
effect.

This commit removes the restriction that strides in memrefs cannot be 0.

Added: 
    

Modified: 
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/Dialect/Affine/memref-stride-calculation.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/IR/invalid-builtin-types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8861a940336133..f288dd42baaa16 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -245,9 +245,6 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
 LogicalResult
 StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           int64_t offset, ArrayRef<int64_t> strides) {
-  if (llvm::is_contained(strides, 0))
-    return emitError() << "strides must not be zero";
-
   return success();
 }
 
@@ -1815,7 +1812,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
   for (const auto &en : llvm::enumerate(strides)) {
     auto dim = en.index();
     auto stride = en.value();
-    assert(stride != 0 && "Invalid stride specification");
     auto d = getAffineDimExpr(dim, context);
     AffineExpr mult;
     // Static case.

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e8e8f3cdfbfd73..6546234429c8cb 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -798,20 +798,6 @@ static LogicalResult getStridesAndOffset(MemRefType t,
   for (auto &stride : strides)
     stride = simplifyAffineExpr(stride, numDims, numSymbols);
 
-  // In practice, a strided memref must be internally non-aliasing. Test
-  // against 0 as a proxy.
-  // TODO: static cases can have more advanced checks.
-  // TODO: dynamic cases would require a way to compare symbolic
-  // expressions and would probably need an affine set context propagated
-  // everywhere.
-  if (llvm::any_of(strides, [](AffineExpr e) {
-        return e == getAffineConstantExpr(0, e.getContext());
-      })) {
-    offset = AffineExpr();
-    strides.clear();
-    return failure();
-  }
-
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
index cce1946b391e7e..29a5f5e0d5f440 100644
--- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
+++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
@@ -51,9 +51,9 @@ func.func @f(%0: index) {
   %26 = memref.alloc(%0)[] : memref<?xf32, affine_map<(i)[M]->(i)>>
 // CHECK: MemRefType offset: 0 strides: 1
   %27 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(M)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (s0)>> cannot be converted to strided form
+// CHECK: MemRefType offset: ? strides: 0
   %28 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(123)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (123)>> cannot be converted to strided form
+// CHECK: MemRefType offset: 123 strides: 0
   %29 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(M)>>
 // CHECK: MemRefType offset: ? strides:
   %30 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(123)>>

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 51c4781c9022b2..f72ad48245f819 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,16 +245,6 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
 
 // -----
 
-func.func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
-  // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
-  %out = memref.reinterpret_cast %in to
-           offset: [0], sizes: [9, 10], strides: [42, 1]
-         : memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
-  return
-}
-
-// -----
-
 func.func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {
   // expected-error @+1 {{element types of source and destination memref types should be the same}}

diff  --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 07854a25000feb..51612446d2e6a6 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -99,11 +99,6 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
 
 // -----
 
-// expected-error @below {{strides must not be zero}}
-func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
-
-// -----
-
 // expected-error @below {{expected the number of strides to match the rank}}
 func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
 


        


More information about the Mlir-commits mailing list