[Mlir-commits] [mlir] [memref] Handle edge case in subview of full static size fold (PR #105635)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 22 03:09:09 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

It is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds:

```mlir
func.func @subview_of_static_full_size(
  %arg0:  memref<16x4xf32,  strided<[4, 1], offset: ?>>, %idx: index)
  -> memref<16x4xf32,  strided<[4, 1], offset: ?>>
{
  %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1]
   : memref<16x4xf32,  strided<[4, 1], offset: ?>>
     to memref<16x4xf32,  strided<[4, 1], offset: ?>>
  return %0 : memref<16x4xf32,  strided<[4, 1], offset: ?>>
}
```

To:

```mlir
func.func @subview_of_static_full_size(
  %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index)
  -> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
  return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
}
```

This fold is incorrect: it drops the dynamic offset from the `subview` op.

---
Full diff: https://github.com/llvm/llvm-project/pull/105635.diff


4 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+4) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+9-6) 
- (modified) mlir/lib/IR/BuiltinAttributes.cpp (+7) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+13) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97bd..f0d41754001400 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
   let extraClassDeclaration = [{
     /// Print the attribute to the given output stream.
     void print(raw_ostream &os) const;
+
+    /// Returns true if this layout is static, i.e. the strides and offset all
+    /// have a known (non-dynamic) value.
+    bool hasStaticLayout() const;
   }];
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 150049e5c5effe..9c021d3613f1c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
-  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
-  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
-
-  if (resultShapedType.hasStaticShape() &&
-      resultShapedType == sourceShapedType) {
+  MemRefType sourceMemrefType = getSource().getType();
+  MemRefType resultMemrefType = getResult().getType();
+  auto resultLayout =
+      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
+
+  if (resultMemrefType == sourceMemrefType &&
+      resultMemrefType.hasStaticShape() &&
+      (!resultLayout || resultLayout.hasStaticLayout())) {
     return getViewSource();
   }
 
@@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
         strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
     bool allSizesSame = llvm::equal(sizes, srcSizes);
     if (allOffsetsZero && allStridesOne && allSizesSame &&
-        resultShapedType == sourceShapedType)
+        resultMemrefType == sourceMemrefType)
       return getViewSource();
   }
 
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d067..8861a940336133 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
   os << ">";
 }
 
+/// Returns true if this layout is static, i.e. the strides and offset all have
+/// a known (non-dynamic) value.
+bool StridedLayoutAttr::hasStaticLayout() const {
+  return !ShapedType::isDynamic(getOffset()) &&
+         !ShapedType::isDynamicShape(getStrides());
+}
+
 /// Returns the strided layout as an affine map.
 AffineMap StridedLayoutAttr::getAffineMap() const {
   return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b15af9baca7dc7..02110bc2892d05 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
 
 // -----
 
+// CHECK-LABEL: func @negative_subview_of_static_full_size
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<16x4xf32,  strided<[4, 1], offset: ?>>
+//  CHECK-SAME:   %[[IDX:.+]]: index
+//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
+//  CHECK-SAME:                    to memref<16x4xf32,  strided<[4, 1], offset: ?>>
+//       CHECK:    return %[[S]] : memref<16x4xf32,  strided<[4, 1], offset: ?>>
+func.func @negative_subview_of_static_full_size(%arg0:  memref<16x4xf32,  strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32,  strided<[4, 1], offset: ?>> {
+  %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32,  strided<[4, 1], offset: ?>> to memref<16x4xf32,  strided<[4, 1], offset: ?>>
+  return %0 : memref<16x4xf32,  strided<[4, 1], offset: ?>>
+}
+
+// -----
+
 func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
     %arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
 {

``````````

</details>


https://github.com/llvm/llvm-project/pull/105635


More information about the Mlir-commits mailing list