[Mlir-commits] [mlir] 244af24 - [mlir][MemRef] Simplify extract_strided_metadata(reinterpret_cast)

Quentin Colombet llvmlistbot at llvm.org
Mon Nov 14 10:42:07 PST 2022


Author: Quentin Colombet
Date: 2022-11-14T18:36:31Z
New Revision: 244af24faf3a2a674f38de7b085482e9f49d76fc

URL: https://github.com/llvm/llvm-project/commit/244af24faf3a2a674f38de7b085482e9f49d76fc
DIFF: https://github.com/llvm/llvm-project/commit/244af24faf3a2a674f38de7b085482e9f49d76fc.diff

LOG: [mlir][MemRef] Simplify extract_strided_metadata(reinterpret_cast)

This patch adds a pattern to simplify
```
base, offset, sizes, strides =
  extract_strided_metadata(
    reinterpret_cast(src, srcOffset, srcSizes, srcStrides))
```

into
```
base, baseOffset, ... = extract_strided_metadata(src)
offset = srcOffset
sizes = srcSizes
strides = srcStrides
```

Note: reinterpret_casts with unranked sources are not simplified, since
unranked memrefs cannot feed extract_strided_metadata operations.

Differential Revision: https://reviews.llvm.org/D135837

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 1ebc2f60cf900..5a95e7ee668d6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -658,6 +658,72 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
   }
 };
 
+/// Replace
+/// ```
+/// base, offset, sizes, strides =
+///     extract_strided_metadata(
+///         reinterpret_cast(src, srcOffset, srcSizes, srcStrides))
+/// ```
+/// With
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = srcOffset
+/// sizes = srcSizes
+/// strides = srcStrides
+/// ```
+///
+/// In other words, bypass the `reinterpret_cast`: the base buffer comes from a
+/// new `extract_strided_metadata` on `src`, while offset, sizes, and strides
+/// are taken from the cast itself. Not applied when `src` has a type that
+/// `extract_strided_metadata` rejects (e.g., an unranked memref).
+class ExtractStridedMetadataOpReinterpretCastFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    auto castOp = extractStridedMetadataOp.getSource()
+                      .getDefiningOp<memref::ReinterpretCastOp>();
+    if (!castOp)
+      return failure();
+
+    Location loc = extractStridedMetadataOp.getLoc();
+    Value castSource = castOp.getSource();
+
+    // Make sure extract_strided_metadata can consume the cast's source by
+    // inferring its result types for that operand (unranked sources fail).
+    SmallVector<Type> sourceMetadataTypes;
+    LogicalResult inferred = extractStridedMetadataOp.inferReturnTypes(
+        rewriter.getContext(), loc, {castSource},
+        /*attributes=*/{}, /*regions=*/{}, sourceMetadataTypes);
+    if (failed(inferred))
+      return rewriter.notifyMatchFailure(
+          castOp, "reinterpret_cast source's type is incompatible");
+
+    unsigned rank = castOp.getResult().getType().cast<MemRefType>().getRank();
+
+    // Only the base buffer is read from the cast's source.
+    auto sourceMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, castSource);
+
+    // Replacement layout: base, offset, sizes[0, rank), strides[0, rank).
+    SmallVector<OpFoldResult> replacements;
+    replacements.reserve(2 + 2 * rank);
+    replacements.push_back(sourceMetadata.getBaseBuffer());
+    replacements.push_back(getValueOrCreateConstantIndexOp(
+        rewriter, loc, castOp.getMixedOffsets()[0]));
+    SmallVector<OpFoldResult> castSizes = castOp.getMixedSizes();
+    SmallVector<OpFoldResult> castStrides = castOp.getMixedStrides();
+    replacements.append(castSizes.begin(), castSizes.begin() + rank);
+    replacements.append(castStrides.begin(), castStrides.begin() + rank);
+    SmallVector<Value> replacementValues =
+        getValueOrCreateConstantIndexOp(rewriter, loc, replacements);
+    rewriter.replaceOp(extractStridedMetadataOp, replacementValues);
+    return success();
+  }
+};
+
 /// Replace `base, offset =
 ///            extract_strided_metadata(extract_strided_metadata(src)#0)`
 /// With
@@ -698,6 +764,7 @@ void memref::populateSimplifyExtractStridedMetadataOpPatterns(
            ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
            ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
            RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+           ExtractStridedMetadataOpReinterpretCastFolder,
            ExtractStridedMetadataOpExtractStridedMetadataFolder>(
           patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 4648d9e4ec74d..b6661ee1b5dd5 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -866,3 +866,136 @@ func.func @extract_strided_metadata_of_extract_strided_metadata(%arg : memref<i3
   return %base2, %offset2 :
       memref<i32>, index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata of reinterpret_cast
+// when the source of the reinterpret_cast has a type that
+// `extract_strided_metadata` accepts.
+//
+// When the transformation applies, the resulting offset, sizes, and strides
+// come straight from the operands of the reinterpret_cast.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast
+//  CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+//       CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
+func.func @extract_strided_metadata_of_reinterpret_cast(
+  %arg : memref<?x?xi32, strided<[?, ?], offset:?>>,
+  %offset: index,
+  %size0 : index, %size1 : index,
+  %stride0 : index, %stride1 : index)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.reinterpret_cast %arg to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1] :
+      memref<?x?xi32, strided<[?, ?], offset: ?>> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// reinterpret_cast when the source of the cast is unranked:
+// unranked memrefs cannot feed into extract_strided_metadata operations.
+// Note: technically, the offset, sizes, and strides could still be folded.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_unranked
+//  CHECK-SAME: %[[ARG:.*]]: memref<*xi32>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+//       CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[DYN_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]]], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]]]
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+//       CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_reinterpret_cast_unranked(
+  %arg : memref<*xi32>,
+  %offset: index,
+  %size0 : index, %size1 : index,
+  %stride0 : index, %stride1 : index)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.reinterpret_cast %arg to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1] :
+      memref<*xi32> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+
+// -----
+
+// Similar to @extract_strided_metadata_of_reinterpret_cast, but make sure we
+// handle a 0-D reinterpret_cast source (cast to a 2-D result) properly.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_rank0
+//  CHECK-SAME: %[[ARG:.*]]: memref<i32, strided<[], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+//       CHECK: %[[BASE:.*]], %[[BASE_OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
+func.func @extract_strided_metadata_of_reinterpret_cast_rank0(
+  %arg : memref<i32, strided<[], offset:?>>,
+  %offset: index,
+  %size0 : index, %size1 : index,
+  %stride0 : index, %stride1 : index)
+  -> (memref<i32>, index,
+      index, index,
+      index, index) {
+
+  %cast =
+    memref.reinterpret_cast %arg to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1] :
+      memref<i32, strided<[], offset: ?>> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}


        


More information about the Mlir-commits mailing list