[Mlir-commits] [mlir] [mlir] Add pattern to bubble up `vector.shape_cast` to cancel them (PR #75881)
Jerry Wu
llvmlistbot at llvm.org
Mon Dec 18 17:14:07 PST 2023
https://github.com/pzread created https://github.com/llvm/llvm-project/pull/75881
None
>From 7ff89545bbca9889ba89fe72704e63ef2fcc2756 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Tue, 19 Dec 2023 01:08:12 +0000
Subject: [PATCH] Add pattern to bubble up shape_cast
---
.../Vector/Transforms/VectorRewritePatterns.h | 4 ++
.../Vector/Transforms/VectorTransforms.cpp | 40 +++++++++++++++++++
2 files changed, 44 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..e9d09c4b6c3a86 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -303,6 +303,10 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Collect patterns that bubble `vector.shape_cast` ops up through the
+/// elementwise ops producing their sources (shape-casting the elementwise
+/// op's operands instead), together with a folder for back-to-back
+/// shape_casts, so that shape_casts pushed toward each other can cancel out.
+void populateBubbleShapeCastPatterns(RewritePatternSet &patterns,
+                                     PatternBenefit benefit = 1);
+
/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..4ecf9250427ea0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1532,6 +1532,40 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// Moves a `vector.shape_cast` above the elementwise op that produces its
+/// source, shape-casting each vector operand instead:
+///
+///   %0 = arith.addf %a, %b : vector<4x2xf32>
+///   %1 = vector.shape_cast %0 : vector<4x2xf32> to vector<8xf32>
+/// ==>
+///   %a1 = vector.shape_cast %a : vector<4x2xf32> to vector<8xf32>
+///   %b1 = vector.shape_cast %b : vector<4x2xf32> to vector<8xf32>
+///   %1 = arith.addf %a1, %b1 : vector<8xf32>
+///
+/// Repeated application pushes shape_casts toward the producers of an
+/// elementwise chain, where they can meet other shape_casts and fold away.
+///
+/// Only single-result, region-free elementwise-mappable ops whose result has
+/// no other users are rewritten. Scalar operands, which the Elementwise trait
+/// permits (e.g. the i1 condition of `arith.select`), are forwarded unchanged.
+struct BubbleUpShapeCastForElementwiseOps final
+    : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *sourceOp = op.getSource().getDefiningOp();
+    if (!sourceOp || !OpTrait::hasElementwiseMappableTraits(sourceOp) ||
+        sourceOp->getNumResults() != 1 || sourceOp->getNumRegions() != 0)
+      return rewriter.notifyMatchFailure(
+          op, "source is not a single-result elementwise op without regions");
+    // If the elementwise result has other users, the original op would stay
+    // alive for them and the rewrite would duplicate the computation rather
+    // than cancel any shape_cast.
+    if (!sourceOp->getResult(0).hasOneUse())
+      return rewriter.notifyMatchFailure(op, "source op has multiple uses");
+
+    VectorType resultType = op.getResultVectorType();
+    Location loc = op.getLoc();
+
+    SmallVector<Value> newOperands;
+    newOperands.reserve(sourceOp->getNumOperands());
+    for (Value operand : sourceOp->getOperands()) {
+      auto operandType = dyn_cast<VectorType>(operand.getType());
+      // Scalar operands are shape-agnostic; keep them as-is instead of
+      // asserting in a VectorType cast.
+      if (!operandType) {
+        newOperands.push_back(operand);
+        continue;
+      }
+      // Take the shape_cast's target shape but keep the operand's own element
+      // type (operand and result element types may differ, e.g. arith.cmpf).
+      VectorType newOperandType =
+          VectorType::Builder(resultType)
+              .setElementType(operandType.getElementType());
+      newOperands.push_back(
+          rewriter.create<vector::ShapeCastOp>(loc, newOperandType, operand));
+    }
+
+    // Recreate the elementwise op generically (by name) at the new shape,
+    // carrying over all of its attributes.
+    Operation *elementwiseOp =
+        rewriter.create(loc, sourceOp->getName().getIdentifier(), newOperands,
+                        resultType, sourceOp->getAttrs());
+    rewriter.replaceOp(op, elementwiseOp);
+    return success();
+  }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1606,6 +1640,12 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
patterns.getContext(), benefit);
}
+/// Registers the shape_cast bubbling pattern followed by the folder that
+/// cancels back-to-back shape_casts, both at the requested benefit.
+void mlir::vector::populateBubbleShapeCastPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<BubbleUpShapeCastForElementwiseOps>(ctx, benefit);
+  patterns.add<ShapeCastOpFolder>(ctx, benefit);
+}
+
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<BubbleDownVectorBitCastForExtract,
More information about the Mlir-commits
mailing list