[Mlir-commits] [mlir] [mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… (PR #68291)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 5 01:58:42 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

…st)`

`expand-strided-metadata` was missing a pattern to get rid of `memref.cast`.
The pattern is straightforward:
Produce a new `extract_strided_metadata` with the source of the cast, and fold the static information (sizes, strides, offset) of the cast's result type along the way.

---
Full diff: https://github.com/llvm/llvm-project/pull/68291.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+88) 
- (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+124) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 672ef3eb4cd50fa..4f3fa6a5ed245f8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
+/// Replace `base, offset, sizes, strides =
+///              extract_strided_metadata(
+///                 cast(src) to dstTy)`
+/// With
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = !isDynamic(dstTy.offset)
+///            ? dstTy.offset
+///            : extract_strided_metadata(src).offset
+/// sizes = for each (i, dstSize) in dstTy.sizes:
+///           !isDynamic(dstSize)
+///             ? dstSize
+///             : extract_strided_metadata(src).sizes[i]
+/// strides = for each (i, dstStride) in dstTy.strides:
+///             !isDynamic(dstStride)
+///               ? dstStride
+///               : extract_strided_metadata(src).strides[i]
+/// ```
+///
+/// In other words, consume the `cast`: static values of `dstTy` are folded
+/// into constants, dynamic ones are read directly from `src`'s metadata.
+class ExtractStridedMetadataOpCastFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    Value source = extractStridedMetadataOp.getSource();
+    auto castOp = source.getDefiningOp<memref::CastOp>();
+    if (!castOp)
+      return failure();
+
+    Location loc = extractStridedMetadataOp.getLoc();
+    // Bail out if `src` cannot feed extract_strided_metadata (e.g. unranked).
+    SmallVector<Type> inferredReturnTypes;
+    if (failed(extractStridedMetadataOp.inferReturnTypes(
+            rewriter.getContext(), loc, {castOp.getSource()},
+            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
+            inferredReturnTypes)))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "cast source's type is incompatible");
+
+    // `dstType` is the cast's result type: its static values are folded.
+    auto dstType = cast<MemRefType>(source.getType());
+    unsigned rank = dstType.getRank();
+    SmallVector<OpFoldResult> results;
+    results.resize_for_overwrite(rank * 2 + 2);
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
+                                                          castOp.getSource());
+
+    // Register the base_buffer.
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+
+    // Prefer the static value from `dstType`, else the value read from `src`.
+    auto getConstantOrValue = [&rewriter](int64_t constant,
+                                          OpFoldResult ofr) -> OpFoldResult {
+      return !ShapedType::isDynamic(constant)
+                 ? OpFoldResult(rewriter.getIndexAttr(constant))
+                 : ofr;
+    };
+    auto [dstStrides, dstOffset] = getStridesAndOffset(dstType);
+    assert(dstStrides.size() == rank && "unexpected number of strides");
+
+    // Register the new offset.
+    results[1] =
+        getConstantOrValue(dstOffset, newExtractStridedMetadata.getOffset());
+
+    const unsigned sizeStartIdx = 2;
+    const unsigned strideStartIdx = sizeStartIdx + rank;
+    ArrayRef<int64_t> dstSizes = dstType.getShape();
+
+    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
+    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
+    for (unsigned i = 0; i < rank; ++i) {
+      results[sizeStartIdx + i] = getConstantOrValue(dstSizes[i], sizes[i]);
+      results[strideStartIdx + i] =
+          getConstantOrValue(dstStrides[i], strides[i]);
+    }
+    rewriter.replaceOp(extractStridedMetadataOp,
+                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
+    return success();
+  }
+};
+
 /// Replace `base, offset =
 ///            extract_strided_metadata(extract_strided_metadata(src)#0)`
 /// With
@@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns(
                ExtractStridedMetadataOpGetGlobalFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
@@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index a6303aa2d971106..4efb38abcd7679c 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1369,3 +1369,127 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
   return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
       memref<i32>, index, index, index, index, index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is a valid operand of
+// `extract_strided_metadata`.
+//
+// The cast's result type is fully dynamic here, so the offset, sizes, and
+// strides come straight from the metadata of the cast's source.
+// Additionally, the folder of extract_strided_metadata propagates the
+// static size (3) of the cast's source type.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+//  CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[?, ?], offset: ?>>)
+//
+//       CHECK: %[[C3:.*]] = arith.constant 3 : index
+//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[DYN_STRIDES]]#1
+func.func @extract_strided_metadata_of_cast(
+  %arg : memref<3x?xi32, strided<[?, ?], offset:?>>)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.cast %arg :
+      memref<3x?xi32, strided<[?, ?], offset: ?>> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is a valid operand of
+// `extract_strided_metadata`.
+//
+// Same as extract_strided_metadata_of_cast but with a static offset, size,
+// and stride in the cast's result type: those become constants.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
+//  CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+//
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
+//   CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
+//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
+func.func @extract_strided_metadata_of_cast_w_csts(
+  %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.cast %arg :
+      memref<?x?xi32, strided<[?, ?], offset: ?>> to
+      memref<4x?xi32, strided<[?, 18], offset: 25>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// cast when the source of the cast is unranked:
+// unranked memrefs cannot feed into extract_strided_metadata operations.
+// Note: in general, static sizes and strides of the result could be folded.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
+//  CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
+//
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+//       CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_cast_unranked(
+  %arg : memref<*xi32>)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.cast %arg :
+      memref<*xi32> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68291


More information about the Mlir-commits mailing list