[Mlir-commits] [mlir] [memref] Handle edge case in subview of full static size fold (PR #105635)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 22 03:09:09 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-ods
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
It is possible to have a subview with a fully static size and a result type that matches the source type, but with a dynamic offset that may be non-zero. However, the memref dialect currently folds:
```mlir
func.func @subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
%0 = memref.subview %arg0[%idx, 0][16, 4][1, 1]
: memref<16x4xf32, strided<[4, 1], offset: ?>>
to memref<16x4xf32, strided<[4, 1], offset: ?>>
return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
}
```
To:
```mlir
func.func @subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
}
```
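This fold is incorrect: it drops the dynamic offset `%idx` from the `subview` op, even though `%idx` may be non-zero at runtime, in which case the original subview and the returned `%arg0` refer to different memory. This patch only applies the fold when the result layout is not a `StridedLayoutAttr` or is a `StridedLayoutAttr` whose offset and strides are all static (see the new `StridedLayoutAttr::hasStaticLayout`).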
---
Full diff: https://github.com/llvm/llvm-project/pull/105635.diff
4 Files Affected:
- (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+4)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+9-6)
- (modified) mlir/lib/IR/BuiltinAttributes.cpp (+7)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+13)
``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97bd..f0d41754001400 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
let extraClassDeclaration = [{
/// Print the attribute to the given output stream.
void print(raw_ostream &os) const;
+
+ /// Returns true if this layout is static, i.e. the offset and all strides
+ /// are known (non-dynamic) values.
+ bool hasStaticLayout() const;
}];
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 150049e5c5effe..9c021d3613f1c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
- auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
- auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
-
- if (resultShapedType.hasStaticShape() &&
- resultShapedType == sourceShapedType) {
+ MemRefType sourceMemrefType = getSource().getType();
+ MemRefType resultMemrefType = getResult().getType();
+ auto resultLayout =
+ dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
+
+ if (resultMemrefType == sourceMemrefType &&
+ resultMemrefType.hasStaticShape() &&
+ (!resultLayout || resultLayout.hasStaticLayout())) {
return getViewSource();
}
@@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
- resultShapedType == sourceShapedType)
+ resultMemrefType == sourceMemrefType)
return getViewSource();
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d067..8861a940336133 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
os << ">";
}
+/// Returns true if this layout is static, i.e. the offset and all strides are
+/// known (non-dynamic) values.
+bool StridedLayoutAttr::hasStaticLayout() const {
+ return !ShapedType::isDynamic(getOffset()) &&
+ !ShapedType::isDynamicShape(getStrides());
+}
+
/// Returns the strided layout as an affine map.
AffineMap StridedLayoutAttr::getAffineMap() const {
return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b15af9baca7dc7..02110bc2892d05 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
// -----
+// CHECK-LABEL: func @negative_subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK-SAME: %[[IDX:.+]]: index
+// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
+// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
+func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
+ %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
+ return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
+}
+
+// -----
+
func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
{
``````````
</details>
https://github.com/llvm/llvm-project/pull/105635
More information about the Mlir-commits mailing list