[Mlir-commits] [mlir] [mlir][vector] Add `extract(transpose(broadcast(x)))` canonicalization (PR #72616)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Dec 14 05:41:19 PST 2023
================
@@ -1898,6 +1897,62 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};
+/// Canonicalize extract(transpose(broadcast)) constructs, where the broadcast
+/// adds a new dimension and the extraction removes it again.
+class ExtractOpTransposedBroadcastDim final
+ : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // Skip vector.extract ops that do not remove any dimensions.
+ if (extractOp.getNumIndices() == 0)
+ return failure();
+ // Look for extract(transpose(broadcast(x))) pattern.
+ auto transposeOp =
+ extractOp.getVector().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp || transposeOp.getPermutation().empty())
+ return failure();
+ auto broadcastOp =
+ transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastOp)
+ return failure();
+ // Check if the first dimension that is being removed by the vector.extract
+ // was added by the vector.broadcast.
+ int64_t removedDim = transposeOp.getPermutation()[0];
+ llvm::SetVector<int64_t> rankExtendedDims =
+ broadcastOp.computeRankExtendedDims();
+ if (!rankExtendedDims.contains(removedDim))
+ return failure();
+
+ // 1. Create new vector.broadcast without the removed dimension.
+ SmallVector<int64_t> newBroadcastShape(
+ broadcastOp.getResultVectorType().getShape());
+ newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ broadcastOp.getLoc(),
+ VectorType::get(newBroadcastShape,
+ broadcastOp.getResultVectorType().getElementType()),
----------------
banach-space wrote:
This won't preserve the scalable dimensions of the result vector type, but the following change would. Try this:
```
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 957143d6c13e..5edff9dd32d1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1930,10 +1930,15 @@ public:
SmallVector<int64_t> newBroadcastShape(
broadcastOp.getResultVectorType().getShape());
newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+ SmallVector<bool> newBroadcastScalableDims(
+ broadcastOp.getResultVectorType().getScalableDims());
+ newBroadcastScalableDims.erase(newBroadcastScalableDims.begin() +
+ removedDim);
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
broadcastOp.getLoc(),
VectorType::get(newBroadcastShape,
- broadcastOp.getResultVectorType().getElementType()),
+ broadcastOp.getResultVectorType().getElementType(),
+ newBroadcastScalableDims),
broadcastOp.getSource());
```
Could you also add a test with at least one scalable dimension? Thanks :)
https://github.com/llvm/llvm-project/pull/72616
More information about the Mlir-commits
mailing list