[Mlir-commits] [mlir] 98c5296 - [mlir][MemRef] Move the forwarding patterns for `extract_strided_metadata`
Quentin Colombet
llvmlistbot at llvm.org
Tue Oct 18 15:35:16 PDT 2022
Author: Quentin Colombet
Date: 2022-10-18T22:34:50Z
New Revision: 98c529652af413eba8df34642613ca5a0e87e52c
URL: https://github.com/llvm/llvm-project/commit/98c529652af413eba8df34642613ca5a0e87e52c
DIFF: https://github.com/llvm/llvm-project/commit/98c529652af413eba8df34642613ca5a0e87e52c.diff
LOG: [mlir][MemRef] Move the forwarding patterns for `extract_strided_metadata`
The `SimplifyExtractStridedMetadata` pass features a pattern that forwards
statically known information (offset, sizes, strides) to their respective
users.
This patch moves this pattern from this pass to the
`extract_strided_metadata` folding patterns.
Differential Revision: https://reviews.llvm.org/D135797
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ba8fe8103269c..1f1b118087f90 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -912,6 +912,8 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
let assemblyFormat = [{
$source `:` type($source) `->` type(results) attr-dict
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ab7311b3d101f..9a6727dff9335 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1286,6 +1286,58 @@ void ExtractStridedMetadataOp::getAsmResultNames(
}
}
+/// Helper function to perform the replacement of all constant uses of `values`
+/// by a materialized constant extracted from `maybeConstants`.
+/// `values` and `maybeConstants` are expected to have the same size.
+template <typename Container>
+static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
+ Container values,
+ ArrayRef<int64_t> maybeConstants,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ assert(values.size() == maybeConstants.size() &&
+ " expected values and maybeConstants of the same size");
+ bool atLeastOneReplacement = false;
+ for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
+ // Don't materialize a constant if there are no uses: this would indice
+ // infinite loops in the driver.
+ if (isDynamic(maybeConstant) || result.use_empty())
+ continue;
+ Value constantVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+ for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
+ // updateRootInplace: lambda cannot capture structured bindings in C++17
+ // yet.
+ op->replaceUsesOfWith(result, constantVal);
+ atLeastOneReplacement = true;
+ }
+ }
+ return atLeastOneReplacement;
+}
+
+LogicalResult
+ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ OpBuilder builder(*this);
+ auto memrefType = getSource().getType().cast<MemRefType>();
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
+ (void)res;
+ assert(succeeded(res) && "must be a strided memref type");
+
+ bool atLeastOneReplacement = replaceConstantUsesOf(
+ builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
+ ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
+ atLeastOneReplacement |=
+ replaceConstantUsesOf(builder, getLoc(), getSizes(),
+ memrefType.getShape(), ShapedType::isDynamic);
+ atLeastOneReplacement |=
+ replaceConstantUsesOf(builder, getLoc(), getStrides(), strides,
+ ShapedType::isDynamicStrideOrOffset);
+
+ return success(atLeastOneReplacement);
+}
+
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 6e861032d35bc..1ebc2f60cf900 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -550,64 +550,6 @@ struct ExtractStridedMetadataOpReshapeFolder
}
};
-/// Helper function to perform the replacement of all constant uses of `values`
-/// by a materialized constant extracted from `maybeConstants`.
-/// `values` and `maybeConstants` are expected to have the same size.
-template <typename Container>
-bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc,
- Container values, ArrayRef<int64_t> maybeConstants,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- assert(values.size() == maybeConstants.size() &&
- " expected values and maybeConstants of the same size");
- bool atLeastOneReplacement = false;
- for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
- // Don't materialize a constant if there are no uses: this would indice
- // infinite loops in the driver.
- if (isDynamic(maybeConstant) || result.use_empty())
- continue;
- Value constantVal =
- rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
- for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
- rewriter.startRootUpdate(op);
- // updateRootInplace: lambda cannot capture structured bindings in C++17
- // yet.
- op->replaceUsesOfWith(result, constantVal);
- rewriter.finalizeRootUpdate(op);
- atLeastOneReplacement = true;
- }
- }
- return atLeastOneReplacement;
-}
-
-// Forward propagate all constants information from an ExtractStridedMetadataOp.
-struct ForwardStaticMetadata
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
- using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
- PatternRewriter &rewriter) const override {
- auto memrefType = metadataOp.getSource().getType().cast<MemRefType>();
- SmallVector<int64_t> strides;
- int64_t offset;
- LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
- (void)res;
- assert(succeeded(res) && "must be a strided memref type");
-
- bool atLeastOneReplacement = replaceConstantUsesOf(
- rewriter, metadataOp.getLoc(),
- ArrayRef<TypedValue<IndexType>>(metadataOp.getOffset()),
- ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
- atLeastOneReplacement |= replaceConstantUsesOf(
- rewriter, metadataOp.getLoc(), metadataOp.getSizes(),
- memrefType.getShape(), ShapedType::isDynamic);
- atLeastOneReplacement |= replaceConstantUsesOf(
- rewriter, metadataOp.getLoc(), metadataOp.getStrides(), strides,
- ShapedType::isDynamicStrideOrOffset);
-
- return success(atLeastOneReplacement);
- }
-};
-
/// Replace `base, offset, sizes, strides =
/// extract_strided_metadata(allocLikeOp)`
///
@@ -753,7 +695,6 @@ void memref::populateSimplifyExtractStridedMetadataOpPatterns(
memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
ExtractStridedMetadataOpReshapeFolder<
memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>,
- ForwardStaticMetadata,
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f20838a470175..5a418022800cf 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -758,9 +758,13 @@ func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: in
// Check that a reinterpret cast of an equivalent extract strided metadata
// is canonicalized to a plain cast when the destination type is different
// than the type of the original memref.
+// This pattern is currently defeated by the constant folding that happens
+// with extract_strided_metadata.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
+// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]]
// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -773,12 +777,12 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
// Check that a reinterpret cast of an equivalent extract strided metadata
// is completely removed when the original memref has the same type.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type
-// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32
// CHECK: return %[[ARG]]
-func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<8x2xf32>) -> memref<8x2xf32> {
- %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
- %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<8x2xf32>
- return %m2 : memref<8x2xf32>
+func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?xf32, strided<[?,?], offset: ?>>) -> memref<?x?xf32, strided<[?,?], offset: ?>> {
+ %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<?x?xf32, strided<[?,?], offset: ?>> -> memref<f32>, index, index, index, index, index
+ %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?,?], offset:?>>
+ return %m2 : memref<?x?xf32, strided<[?,?], offset:?>>
}
// -----
@@ -787,8 +791,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<8x2x
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -801,8 +807,11 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1]
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index caa7efdcc6c3a..4648d9e4ec74d 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -193,10 +193,10 @@ func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
-> (memref<f32>, index, index, index, index, index) {
%subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] :
- memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
+ memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
- memref<6x3xf32, strided<[4, 1], offset: 210>>
+ memref<6x3xf32, strided<[?, 1], offset: 210>>
-> memref<f32>, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
More information about the Mlir-commits
mailing list