[Mlir-commits] [mlir] [mlir][vector] Add `extract(transpose(broadcast(x)))` canonicalization (PR #72616)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Dec 14 05:41:19 PST 2023


================
@@ -1898,6 +1897,62 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Canonicalize extract(transpose(broadcast(x))) constructs, where the broadcast
+/// adds a new dimension and the extraction removes it again.
+class ExtractOpTransposedBroadcastDim final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Skip vector.extract ops that do not remove any dimensions.
+    if (extractOp.getNumIndices() == 0)
+      return failure();
+    // Look for extract(transpose(broadcast(x))) pattern.
+    auto transposeOp =
+        extractOp.getVector().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp || transposeOp.getPermutation().empty())
+      return failure();
+    auto broadcastOp =
+        transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp)
+      return failure();
+    // Check if the first dimension that is being removed by the vector.extract
+    // was added by the vector.broadcast.
+    int64_t removedDim = transposeOp.getPermutation()[0];
+    llvm::SetVector<int64_t> rankExtendedDims =
+        broadcastOp.computeRankExtendedDims();
+    if (!rankExtendedDims.contains(removedDim))
+      return failure();
+
+    // 1. Create new vector.broadcast without the removed dimension.
+    SmallVector<int64_t> newBroadcastShape(
+        broadcastOp.getResultVectorType().getShape());
+    newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+    auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+        broadcastOp.getLoc(),
+        VectorType::get(newBroadcastShape,
+                        broadcastOp.getResultVectorType().getElementType()),
----------------
banach-space wrote:

This won't preserve scalability. Try this instead:
```
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 957143d6c13e..5edff9dd32d1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1930,10 +1930,15 @@ public:
     SmallVector<int64_t> newBroadcastShape(
         broadcastOp.getResultVectorType().getShape());
     newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+    SmallVector<bool> newBroadcastScalableDims(
+        broadcastOp.getResultVectorType().getScalableDims());
+    newBroadcastScalableDims.erase(newBroadcastScalableDims.begin() +
+                                   removedDim);
     auto newBroadcast = rewriter.create<vector::BroadcastOp>(
         broadcastOp.getLoc(),
         VectorType::get(newBroadcastShape,
-                        broadcastOp.getResultVectorType().getElementType()),
+                        broadcastOp.getResultVectorType().getElementType(),
+                        newBroadcastScalableDims),
         broadcastOp.getSource());
```
Could you also add a test with at least 1 scalable dim? Thanks :)
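For reference, a rough sketch of what a test with a scalable dim could look like (the function name, shapes and CHECK lines below are illustrative, not taken from the PR, and assume the leftover identity transpose and now-trivial extract fold away under -canonicalize; it would presumably go next to the existing tests in mlir/test/Dialect/Vector/canonicalize.mlir):
```mlir
// The broadcast adds two leading fixed-size dims (2 and 3), the transpose
// moves the size-3 dim to the front, and the extract removes it again, so
// the chain is expected to fold to a single, smaller broadcast that keeps
// the trailing dim scalable.

// CHECK-LABEL: func.func @extract_dim_added_by_broadcast_scalable(
//  CHECK-SAME:     %[[SRC:.*]]: vector<[4]xf32>
//       CHECK:     %[[BCAST:.*]] = vector.broadcast %[[SRC]] : vector<[4]xf32> to vector<2x[4]xf32>
//       CHECK:     return %[[BCAST]] : vector<2x[4]xf32>
func.func @extract_dim_added_by_broadcast_scalable(%src: vector<[4]xf32>) -> vector<2x[4]xf32> {
  %bcast = vector.broadcast %src : vector<[4]xf32> to vector<2x3x[4]xf32>
  %transp = vector.transpose %bcast, [1, 0, 2] : vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
  %extract = vector.extract %transp[0] : vector<2x[4]xf32> from vector<3x2x[4]xf32>
  return %extract : vector<2x[4]xf32>
}
```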

https://github.com/llvm/llvm-project/pull/72616

