[Mlir-commits] [mlir] 244af24 - [mlir][MemRef] Simplify extract_strided_metadata(reinterpret_cast)
Quentin Colombet
llvmlistbot at llvm.org
Mon Nov 14 10:42:07 PST 2022
Author: Quentin Colombet
Date: 2022-11-14T18:36:31Z
New Revision: 244af24faf3a2a674f38de7b085482e9f49d76fc
URL: https://github.com/llvm/llvm-project/commit/244af24faf3a2a674f38de7b085482e9f49d76fc
DIFF: https://github.com/llvm/llvm-project/commit/244af24faf3a2a674f38de7b085482e9f49d76fc.diff
LOG: [mlir][MemRef] Simplify extract_strided_metadata(reinterpret_cast)
This patch adds a pattern to simplify
```
base, offset, sizes, strides =
extract_strided_metadata(
reinterpret_cast(src, srcOffset, srcSizes, srcStrides))
```
into
```
base, baseOffset, ... = extract_strided_metadata(src)
offset = srcOffset
sizes = srcSizes
strides = srcStrides
```
Note: reinterpret_casts with unranked sources are not simplified, since
unranked memrefs cannot feed extract_strided_metadata operations.
Differential Revision: https://reviews.llvm.org/D135837
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 1ebc2f60cf900..5a95e7ee668d6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -658,6 +658,72 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
}
};
+/// Fold `extract_strided_metadata(reinterpret_cast(src, ...))`.
+///
+/// Rewrites
+/// ```
+/// base, offset, sizes, strides =
+///     extract_strided_metadata(
+///         reinterpret_cast(src, srcOffset, srcSizes, srcStrides))
+/// ```
+/// into
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = srcOffset
+/// sizes = srcSizes
+/// strides = srcStrides
+/// ```
+///
+/// Only the base buffer is still read from `src`. The offset, sizes, and
+/// strides are taken verbatim from the reinterpret_cast. The pattern does not
+/// apply when `src` cannot feed an extract_strided_metadata (e.g. unranked).
+class ExtractStridedMetadataOpReinterpretCastFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    Value castResult = extractStridedMetadataOp.getSource();
+    auto castOp = castResult.getDefiningOp<memref::ReinterpretCastOp>();
+    if (!castOp)
+      return failure();
+
+    Value castSource = castOp.getSource();
+    Location loc = extractStridedMetadataOp.getLoc();
+
+    // Probe whether `castSource` is a legal extract_strided_metadata operand;
+    // only success/failure matters, the inferred types are discarded.
+    SmallVector<Type> probedTypes;
+    if (failed(extractStridedMetadataOp.inferReturnTypes(
+            rewriter.getContext(), loc, {castSource},
+            /*attributes=*/{}, /*regions=*/{}, probedTypes)))
+      return rewriter.notifyMatchFailure(
+          castOp, "reinterpret_cast source's type is incompatible");
+
+    // The replaced op yields 2 + 2 * rank values, where rank is that of the
+    // reinterpret_cast *result*; the cast's source may have another rank.
+    unsigned rank = castResult.getType().cast<MemRefType>().getRank();
+
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, castSource);
+
+    SmallVector<OpFoldResult> replacements;
+    replacements.reserve(2 + 2 * rank);
+    // #0: the base buffer, the only metadata still read from the source.
+    replacements.push_back(newExtractStridedMetadata.getBaseBuffer());
+    // #1: the offset imposed by the reinterpret_cast.
+    replacements.push_back(getValueOrCreateConstantIndexOp(
+        rewriter, loc, castOp.getMixedOffsets().front()));
+    // Then `rank` sizes followed by `rank` strides, as the cast specified.
+    SmallVector<OpFoldResult> sizes = castOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = castOp.getMixedStrides();
+    replacements.append(sizes.begin(), sizes.begin() + rank);
+    replacements.append(strides.begin(), strides.begin() + rank);
+
+    // Static entries are materialized as index constants here.
+    rewriter.replaceOp(
+        extractStridedMetadataOp,
+        getValueOrCreateConstantIndexOp(rewriter, loc, replacements));
+    return success();
+  }
+};
+
/// Replace `base, offset =
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
@@ -698,6 +764,7 @@ void memref::populateSimplifyExtractStridedMetadataOpPatterns(
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+ ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 4648d9e4ec74d..b6661ee1b5dd5 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -866,3 +866,136 @@ func.func @extract_strided_metadata_of_extract_strided_metadata(%arg : memref<i3
return %base2, %offset2 :
memref<i32>, index
}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of reinterpret_cast
+// when the source of the reinterpret_cast is compatible with what
+// `extract_strided_metadata` accepts (here, a ranked memref with a strided
+// layout).
+//
+// When the transformation applies, the resulting offset, sizes, and strides
+// should come straight from the operands of the reinterpret_cast; only the
+// base buffer is still produced by an extract_strided_metadata on %arg.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast
+// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
+func.func @extract_strided_metadata_of_reinterpret_cast(
+ %arg : memref<?x?xi32, strided<[?, ?], offset:?>>,
+ %offset: index,
+ %size0 : index, %size1 : index,
+ %stride0 : index, %stride1 : index)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ // Offset, sizes, and strides are all dynamic operands of the cast, so the
+ // expected results are exactly the function arguments.
+ %cast =
+ memref.reinterpret_cast %arg to
+ offset: [%offset],
+ sizes: [%size0, %size1],
+ strides: [%stride0, %stride1] :
+ memref<?x?xi32, strided<[?, ?], offset: ?>> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ // Return every metadata result so that each one is verified.
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// reinterpret_cast when the source of the cast is unranked.
+// Unranked memrefs cannot feed into extract_strided_metadata operations,
+// so there is no way to recover the base buffer from %arg.
+// Note: technically, we could still fold the sizes and strides, since they
+// come from the reinterpret_cast operands.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_unranked
+// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[DYN_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]]], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]]]
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_reinterpret_cast_unranked(
+ %arg : memref<*xi32>,
+ %offset: index,
+ %size0 : index, %size1 : index,
+ %stride0 : index, %stride1 : index)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ // The cast source is unranked, so this cast and the
+ // extract_strided_metadata below are expected to survive unchanged.
+ %cast =
+ memref.reinterpret_cast %arg to
+ offset: [%offset],
+ sizes: [%size0, %size1],
+ strides: [%stride0, %stride1] :
+ memref<*xi32> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
+// Similar to @extract_strided_metadata_of_reinterpret_cast, but make sure
+// we handle a 0-D source properly. The new extract_strided_metadata on %arg
+// only yields a base buffer and an offset, while the rank-2 sizes and strides
+// must still come from the reinterpret_cast operands.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_rank0
+// CHECK-SAME: %[[ARG:.*]]: memref<i32, strided<[], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+// CHECK: %[[BASE:.*]], %[[BASE_OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
+func.func @extract_strided_metadata_of_reinterpret_cast_rank0(
+ %arg : memref<i32, strided<[], offset:?>>,
+ %offset: index,
+ %size0 : index, %size1 : index,
+ %stride0 : index, %stride1 : index)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ // 0-D source cast to a rank-2 result: the number of metadata results is
+ // driven by the cast result's rank, not by the source's rank.
+ %cast =
+ memref.reinterpret_cast %arg to
+ offset: [%offset],
+ sizes: [%size0, %size1],
+ strides: [%stride0, %stride1] :
+ memref<i32, strided<[], offset: ?>> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
More information about the Mlir-commits
mailing list