[Mlir-commits] [mlir] [mlir] Add pattern to bubble up `vector.shape_cast` to cancel them (PR #75881)

Jerry Wu llvmlistbot at llvm.org
Mon Dec 18 17:20:01 PST 2023


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/75881

>From 7ff89545bbca9889ba89fe72704e63ef2fcc2756 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Tue, 19 Dec 2023 01:08:12 +0000
Subject: [PATCH 1/2] Add pattern to bubble up shape_cast

---
 .../Vector/Transforms/VectorRewritePatterns.h |  4 ++
 .../Vector/Transforms/VectorTransforms.cpp    | 40 +++++++++++++++++++
 2 files changed, 44 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..e9d09c4b6c3a86 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -303,6 +303,10 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
 void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);
 
+/// Bubble vector.shape_cast ops up through elementwise ops so they can cancel.
+void populateBubbleShapeCastPatterns(RewritePatternSet &patterns,
+                                     PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
 /// memref.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..4ecf9250427ea0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1532,6 +1532,40 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// Bubbles a vector.shape_cast above its single-use (so nothing is duplicated)
+/// elementwise producer; only vector operands get cast, scalars pass through.
+struct BubbleUpShapeCastForElementwiseOps final
+    : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *sourceOp = op.getSource().getDefiningOp();
+    if (!(sourceOp && OpTrait::hasElementwiseMappableTraits(sourceOp) &&
+          sourceOp->getNumResults() == 1 && sourceOp->getNumRegions() == 0 &&
+          sourceOp->hasOneUse()))
+      return failure();
+
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType resultType = op.getResultVectorType();
+    auto loc = op.getLoc();
+
+    SmallVector<Value> newOperands;
+    for (Value operand : sourceOp->getOperands()) {
+      Value newOperand = operand; // Scalar operands are kept as is.
+      if (auto vecType = dyn_cast<VectorType>(operand.getType())) {
+        VectorType castTy = VectorType::Builder(resultType).setElementType(
+            vecType.getElementType());
+        newOperand = rewriter.create<vector::ShapeCastOp>(loc, castTy, operand);
+      }
+      newOperands.push_back(newOperand);
+    }
+    rewriter.replaceOp(
+        op, rewriter.create(loc, sourceOp->getName().getIdentifier(),
+                            newOperands, resultType, sourceOp->getAttrs()));
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1606,6 +1640,12 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateBubbleShapeCastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<BubbleUpShapeCastForElementwiseOps, ShapeCastOpFolder>(
+      patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<BubbleDownVectorBitCastForExtract,

>From fe48318c1bf584536a772f8c5d82599c1819765a Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Tue, 19 Dec 2023 01:19:48 +0000
Subject: [PATCH 2/2] Skip zero dim

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 4ecf9250427ea0..3b639e21280d21 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1545,6 +1545,11 @@ struct BubbleUpShapeCastForElementwiseOps final
 
     VectorType sourceType = op.getSourceVectorType();
     VectorType resultType = op.getResultVectorType();
+    // Skip 0-D results for now: some elementwise ops don't support them.
+    bool isZeroRank = resultType.getRank() == 0;
+    if (isZeroRank)
+      return failure();
+
     auto loc = op.getLoc();
 
     SmallVector<Value> newOperands;



More information about the Mlir-commits mailing list