[Mlir-commits] [mlir] f7e1ce0 - [mlir][MemRef] Add pattern that forwards constant strided metadata.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Sep 26 08:34:40 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-26T08:34:31-07:00
New Revision: f7e1ce0f3071849fe9c7932a92038c51105fd8bc
URL: https://github.com/llvm/llvm-project/commit/f7e1ce0f3071849fe9c7932a92038c51105fd8bc
DIFF: https://github.com/llvm/llvm-project/commit/f7e1ce0f3071849fe9c7932a92038c51105fd8bc.diff
LOG: [mlir][MemRef] Add pattern that forwards constant strided metadata.
`memref.extract_strided_metadata` can forward the constant offset, sizes, and
strides of its source memref type independently of the existence of other
operations such as subview or reshape.
Differential Revision: https://reviews.llvm.org/D134603
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 16d17aa0b183b..dcadd5b33d078 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
namespace mlir {
@@ -428,13 +429,72 @@ struct ExtractStridedMetadataOpExpandShapeFolder
return success();
}
};
+
+/// Helper function to perform the replacement of all constant uses of `values`
+/// by a materialized constant extracted from `maybeConstants`.
+/// `values` and `maybeConstants` are expected to have the same size.
+/// Entries of `maybeConstants` for which `isDynamic` returns true, and values
+/// that have no uses, are skipped.
+/// Returns true if at least one use was replaced.
+template <typename Container>
+bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc,
+                           Container values, ArrayRef<int64_t> maybeConstants,
+                           llvm::function_ref<bool(int64_t)> isDynamic) {
+  assert(values.size() == maybeConstants.size() &&
+         " expected values and maybeConstants of the same size");
+  bool atLeastOneReplacement = false;
+  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
+    // Don't materialize a constant if there are no uses: this would induce
+    // infinite loops in the driver.
+    if (isDynamic(maybeConstant) || result.use_empty())
+      continue;
+    Value constantVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+    // Iterate over the individual uses (OpOperands) rather than the users.
+    // `Operation::replaceUsesOfWith` rewrites *all* operands of a user that
+    // refer to `result`. When one op uses `result` more than once, that
+    // unlinks the use the early-increment iterator has already advanced to.
+    // Setting one operand at a time only unlinks the use just visited, so
+    // the iterator stays on `result`'s use list.
+    // The lambda captures only `use` and `constantVal` (no structured
+    // bindings), so updateRootInPlace is usable in C++17.
+    for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
+      rewriter.updateRootInPlace(use.getOwner(),
+                                 [&]() { use.set(constantVal); });
+      atLeastOneReplacement = true;
+    }
+  }
+  return atLeastOneReplacement;
+}
+
+/// Forward propagate all constant information (offset, sizes and strides)
+/// carried by the source memref type of an ExtractStridedMetadataOp to the
+/// users of the corresponding results.
+struct ForwardStaticMetadata
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = metadataOp.getSource().getType().cast<MemRefType>();
+    SmallVector<int64_t> staticStrides;
+    int64_t staticOffset;
+    LogicalResult stridedResult =
+        getStridesAndOffset(sourceType, staticStrides, staticOffset);
+    (void)stridedResult;
+    assert(succeeded(stridedResult) && "must be a strided memref type");
+
+    Location loc = metadataOp.getLoc();
+    // The offset is a single value: wrap it (and its static counterpart) in
+    // one-element arrays so it goes through the same helper as sizes and
+    // strides. Both arrays refer to named locals that outlive the call.
+    TypedValue<IndexType> offsetValue = metadataOp.getOffset();
+    ArrayRef<TypedValue<IndexType>> offsetValues(offsetValue);
+    ArrayRef<int64_t> staticOffsets(staticOffset);
+
+    // Every helper call must run regardless of the others' outcome, so the
+    // results are collected first and only combined afterwards.
+    bool offsetReplaced = replaceConstantUsesOf(
+        rewriter, loc, offsetValues, staticOffsets,
+        ShapedType::isDynamicStrideOrOffset);
+    bool sizesReplaced =
+        replaceConstantUsesOf(rewriter, loc, metadataOp.getSizes(),
+                              sourceType.getShape(), ShapedType::isDynamic);
+    bool stridesReplaced = replaceConstantUsesOf(
+        rewriter, loc, metadataOp.getStrides(), staticStrides,
+        ShapedType::isDynamicStrideOrOffset);
+
+    return success(offsetReplaced || sizesReplaced || stridesReplaced);
+  }
+};
} // namespace
void memref::populateSimplifyExtractStridedMetadataOpPatterns(
RewritePatternSet &patterns) {
- patterns.add<ExtractStridedMetadataOpSubviewFolder,
- ExtractStridedMetadataOpExpandShapeFolder>(
- patterns.getContext());
+  // ForwardStaticMetadata only reads the static source memref type, so it
+  // applies regardless of which operation produced the source memref.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExtractStridedMetadataOpSubviewFolder,
+               ExtractStridedMetadataOpExpandShapeFolder,
+               ForwardStaticMetadata>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 3b2b00d2dc6fa..dd5889fe47fd9 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -1,5 +1,24 @@
// RUN: mlir-opt --simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s
+// Check that the offset, sizes, and strides that are statically known from
+// the source memref type are forwarded as `arith.constant` values to the
+// users of memref.extract_strided_metadata, while the base buffer is still
+// taken from the op itself.
+// CHECK-LABEL: func @extract_strided_metadata_constants
+// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+ -> (memref<f32>, index, index, index, index, index) {
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+
+ // The op survives because %base_buffer is still returned below.
+ // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+ %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
+ memref<5x4xf32, strided<[4,1], offset:2>>
+ -> memref<f32>, index, index, index, index, index
+
+ // CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]]
+ return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+ memref<f32>, index, index, index, index, index
+}
+
// -----
// Check that we simplify extract_strided_metadata of subview to
More information about the Mlir-commits
mailing list