[Mlir-commits] [mlir] [mlir] [IR] Allow zero strides in StridedLayoutAttr (PR #116463)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 15 21:48:30 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: donald chen (cxy-1993)
<details>
<summary>Changes</summary>
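Disallowing memrefs with a stride of 0 was intended to prevent internal aliasing, but it does not cover all cases: internal aliasing can still occur whenever a dimension's stride is smaller than the extent spanned by the dimensions nested inside it (for example, `memref<4x4xf32, strided<[2, 1]>>`, whose rows overlap).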
On the other hand, a stride of 0 can be very useful in certain scenarios. For example, on architectures that support multi-dimensional DMA, `memref.copy` with a stride of 0 can be used to achieve a broadcast effect.
This commit removes the restriction that memref strides must be nonzero.
---
Full diff: https://github.com/llvm/llvm-project/pull/116463.diff
5 Files Affected:
- (modified) mlir/lib/IR/BuiltinAttributes.cpp (-4)
- (modified) mlir/lib/IR/BuiltinTypes.cpp (-14)
- (modified) mlir/test/Dialect/Affine/memref-stride-calculation.mlir (+2-2)
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (-10)
- (modified) mlir/test/IR/invalid-builtin-types.mlir (-5)
``````````diff
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8861a940336133..f288dd42baaa16 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -245,9 +245,6 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
LogicalResult
StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t offset, ArrayRef<int64_t> strides) {
- if (llvm::is_contained(strides, 0))
- return emitError() << "strides must not be zero";
-
return success();
}
@@ -1815,7 +1812,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
for (const auto &en : llvm::enumerate(strides)) {
auto dim = en.index();
auto stride = en.value();
- assert(stride != 0 && "Invalid stride specification");
auto d = getAffineDimExpr(dim, context);
AffineExpr mult;
// Static case.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 25e9f80c9963cb..c28c580690166f 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -803,20 +803,6 @@ static LogicalResult getStridesAndOffset(MemRefType t,
for (auto &stride : strides)
stride = simplifyAffineExpr(stride, numDims, numSymbols);
- // In practice, a strided memref must be internally non-aliasing. Test
- // against 0 as a proxy.
- // TODO: static cases can have more advanced checks.
- // TODO: dynamic cases would require a way to compare symbolic
- // expressions and would probably need an affine set context propagated
- // everywhere.
- if (llvm::any_of(strides, [](AffineExpr e) {
- return e == getAffineConstantExpr(0, e.getContext());
- })) {
- offset = AffineExpr();
- strides.clear();
- return failure();
- }
-
return success();
}
diff --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
index cce1946b391e7e..29a5f5e0d5f440 100644
--- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
+++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir
@@ -51,9 +51,9 @@ func.func @f(%0: index) {
%26 = memref.alloc(%0)[] : memref<?xf32, affine_map<(i)[M]->(i)>>
// CHECK: MemRefType offset: 0 strides: 1
%27 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(M)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (s0)>> cannot be converted to strided form
+// CHECK: MemRefType offset: ? strides: 0
%28 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(123)>>
-// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (123)>> cannot be converted to strided form
+// CHECK: MemRefType offset: 123 strides: 0
%29 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(M)>>
// CHECK: MemRefType offset: ? strides:
%30 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(123)>>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 51c4781c9022b2..f72ad48245f819 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -245,16 +245,6 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
// -----
-func.func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
- // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
- %out = memref.reinterpret_cast %in to
- offset: [0], sizes: [9, 10], strides: [42, 1]
- : memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
- return
-}
-
-// -----
-
func.func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 07854a25000feb..51612446d2e6a6 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -99,11 +99,6 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
// -----
-// expected-error @below {{strides must not be zero}}
-func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
-
-// -----
-
// expected-error @below {{expected the number of strides to match the rank}}
func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
``````````
</details>
https://github.com/llvm/llvm-project/pull/116463
More information about the Mlir-commits
mailing list