[Mlir-commits] [mlir] 4abccd3 - [mlir][memref][transform] Register memref dialect patterns
Matthias Springer
llvmlistbot at llvm.org
Sun Jun 4 23:50:05 PDT 2023
Author: Matthias Springer
Date: 2023-06-05T08:43:39+02:00
New Revision: 4abccd3913be0fc56e0383e04b3c0a4b872db767
URL: https://github.com/llvm/llvm-project/commit/4abccd3913be0fc56e0383e04b3c0a4b872db767
DIFF: https://github.com/llvm/llvm-project/commit/4abccd3913be0fc56e0383e04b3c0a4b872db767.diff
LOG: [mlir][memref][transform] Register memref dialect patterns
Differential Revision: https://reviews.llvm.org/D151998
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 1fe13423fa5ce..91ef1620fce64 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-/// This header declares functions that assit transformations in the MemRef
+/// This header declares functions that assist transformations in the MemRef
/// dialect.
//
//===----------------------------------------------------------------------===//
@@ -44,9 +44,9 @@ void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
-/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
+/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
/// operands.
-void populateResolveRankedShapeTypeResultDimsPatterns(
+void populateResolveRankedShapedTypeResultDimsPatterns(
RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
@@ -68,7 +68,7 @@ void populateMemRefWideIntEmulationPatterns(
arith::WideIntEmulationConverter &typeConverter,
RewritePatternSet &patterns);
-/// Appends type converions for emulating wide integer memref operations with
+/// Appends type conversions for emulating wide integer memref operations with
/// ops over narrowe integer types.
void populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d8eccb9675894..f23830699aeb9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -675,7 +675,7 @@ void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::populateFoldTensorEmptyPatterns(patterns);
- memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
@@ -689,7 +689,7 @@ void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::populateFoldTensorEmptyPatterns(patterns);
- memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index f8b449122feb5..7b636137e20ac 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -161,6 +162,23 @@ class MemRefTransformDialectExtension
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
>();
+
+ addDialectDataInitializer<transform::PatternRegistry>(
+ [&](transform::PatternRegistry &registry) {
+ registry.registerPatterns("memref.expand_ops",
+ memref::populateExpandOpsPatterns);
+ registry.registerPatterns("memref.fold_memref_alias_ops",
+ memref::populateFoldMemRefAliasOpPatterns);
+ registry.registerPatterns(
+ "memref.resolve_ranked_shaped_type_result_dims",
+ memref::populateResolveRankedShapedTypeResultDimsPatterns);
+ registry.registerPatterns(
+ "memref.expand_strided_metadata",
+ memref::populateExpandStridedMetadataPatterns);
+ registry.registerPatterns(
+ "memref.extract_address_computations",
+ memref::populateExtractAddressComputationsPatterns);
+ });
}
};
} // namespace
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 526c1c6e198ff..9e5fc73bea06a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -121,7 +121,7 @@ struct ResolveShapedTypeResultDimsPass final
} // namespace
-void memref::populateResolveRankedShapeTypeResultDimsPatterns(
+void memref::populateResolveRankedShapedTypeResultDimsPatterns(
RewritePatternSet &patterns) {
patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
@@ -138,14 +138,14 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list