[Mlir-commits] [mlir] [mlir][vector] Add extract(transpose(broadcast(x))) canonicalization (PR #72616)

Matthias Springer llvmlistbot at llvm.org
Thu Nov 16 23:32:29 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/72616

Fold IR patterns where an extract is removing a dimension that was added by a broadcast, and where there is a transpose op between these two ops.

Depends on #72594. Review only the top commit.


>From a4b57fdb77f637f82bb556d74f146138bbf1fab8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 11:10:50 +0900
Subject: [PATCH 1/2] [mlir][vector] Modernize `vector.transpose` op

* Declare arguments/results with `let` statements.
* Rename `transp` to `permutation`.
* Change type of `transp` from `I64ArrayAttr` to `DenseI64ArrayAttr`.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 17 +++--
 mlir/include/mlir/IR/AffineMap.h              |  2 +
 .../VectorToArmSME/VectorToArmSME.cpp         |  7 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  7 +-
 .../Dialect/Arith/Transforms/IntNarrowing.cpp |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 64 +++++++------------
 .../Transforms/LowerVectorTranspose.cpp       |  4 +-
 .../Vector/Transforms/VectorTransforms.cpp    | 20 +++---
 .../Vector/Transforms/VectorUnroll.cpp        |  3 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  7 +-
 mlir/lib/IR/AffineMap.cpp                     |  6 ++
 11 files changed, 59 insertions(+), 80 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e80a6446dfb24..1397d4caf1d9d61 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2436,14 +2436,13 @@ def Vector_TransposeOp :
   Vector_Op<"transpose", [Pure,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     PredOpTrait<"operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
-    Results<(outs AnyVectorOfAnyRank:$result)> {
+                 TCresVTEtIsSameAsOpBase<0, 0>>]> {
   let summary = "vector transpose operation";
   let description = [{
     Takes a n-D vector and returns the transposed n-D vector defined by
     the permutation of ranks in the n-sized integer array attribute (in case
     of 0-D vectors the array attribute must be empty).
+
     In the operation
 
     ```mlir
@@ -2452,7 +2451,7 @@ def Vector_TransposeOp :
       to vector<d_trans[0] x .. x d_trans[n-1] x f32>
     ```
 
-    the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+    the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
 
     Example:
 
@@ -2464,8 +2463,13 @@ def Vector_TransposeOp :
                           [c, f] ]
     ```
   }];
+
+  let arguments = (ins AnyVectorOfAnyRank:$vector,
+                       DenseI64ArrayAttr:$permutation);
+  let results = (outs AnyVectorOfAnyRank:$result);
+
   let builders = [
-    OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
+    OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
   ];
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
@@ -2474,10 +2478,9 @@ def Vector_TransposeOp :
     VectorType getResultVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
-    void getTransp(SmallVectorImpl<int64_t> &results);
   }];
   let assemblyFormat = [{
-    $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+    $vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 713aef767edf669..981f3d392cbc98c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -103,6 +103,8 @@ class AffineMap {
   /// (i.e. `[1,1,2]` is an invalid permutation).
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
+  static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
+                                     MLIRContext *context);
 
   /// Returns an affine map with `numDims` input dimensions and results
   /// specified by `targets`.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..01c782676068d9a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    SmallVector<int64_t> transp;
-    for (auto attr : transposeOp.getTransp())
-      transp.push_back(cast<IntegerAttr>(attr).getInt());
-
     // Bail unless this is a true 2-D matrix transpose.
-    if (transp[0] != 1 || transp[1] != 0)
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    if (permutation[0] != 1 || permutation[1] != 0)
       return failure();
 
     OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1126c2c20758c7a..429d1137b6f3781 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
 
-    SmallVector<int64_t, 2> perm;
-    op.getTransp(perm);
-    SmallVector<unsigned, 2> permU;
-    for (int64_t o : perm)
-      permU.push_back(unsigned(o));
     AffineMap permutationMap =
-        AffineMap::getPermutationMap(permU, op.getContext());
+        AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
     AffineMap newMap =
         permutationMap.compose(transferReadOp.getPermutationMap());
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 1084fbc890053b9..79fabd6ed2e99a2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
     VectorType newTy =
         origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newTranspose = rewriter.create<vector::TransposeOp>(
-        op.getLoc(), newTy, ext->getIn(), op.getTransp());
+        op.getLoc(), newTy, ext->getIn(), op.getPermutation());
     ext->recreateAndReplace(rewriter, op, newTranspose);
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..c7b74701fdbc8f2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
 
   if (!nextTransposeOp)
     return failure();
-  auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
-  AffineMap m = inversePermutation(
-      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
+      nextTransposeOp.getPermutation(), extractOp.getContext()));
   extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
   return success();
 }
@@ -5376,20 +5375,20 @@ LogicalResult TypeCastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
-                                Value vector, ArrayRef<int64_t> transp) {
+                                Value vector, ArrayRef<int64_t> permutation) {
   VectorType vt = llvm::cast<VectorType>(vector.getType());
   SmallVector<int64_t, 4> transposedShape(vt.getRank());
   SmallVector<bool, 4> transposedScalableDims(vt.getRank());
-  for (unsigned i = 0; i < transp.size(); ++i) {
-    transposedShape[i] = vt.getShape()[transp[i]];
-    transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
+  for (unsigned i = 0; i < permutation.size(); ++i) {
+    transposedShape[i] = vt.getShape()[permutation[i]];
+    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
   }
 
   result.addOperands(vector);
   result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                   transposedScalableDims));
-  result.addAttribute(TransposeOp::getTranspAttrName(result.name),
-                      builder.getI64ArrayAttr(transp));
+  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
+                      builder.getDenseI64ArrayAttr(permutation));
 }
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
@@ -5401,13 +5400,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
-  SmallVector<int64_t, 4> transp;
-  getTransp(transp);
+  ArrayRef<int64_t> perm = getPermutation();
 
   // Check if the permutation of the dimensions contains sequential values:
   // {0, 1, 2, ...}.
-  for (int64_t i = 0, e = transp.size(); i < e; i++) {
-    if (transp[i] != i)
+  for (int64_t i = 0, e = perm.size(); i < e; i++) {
+    if (perm[i] != i)
       return {};
   }
 
@@ -5421,20 +5419,19 @@ LogicalResult vector::TransposeOp::verify() {
   if (vectorType.getRank() != rank)
     return emitOpError("vector result rank mismatch: ") << rank;
   // Verify transposition array.
-  auto transpAttr = getTransp().getValue();
-  int64_t size = transpAttr.size();
+  ArrayRef<int64_t> perm = getPermutation();
+  int64_t size = perm.size();
   if (rank != size)
     return emitOpError("transposition length mismatch: ") << size;
   SmallVector<bool, 8> seen(rank, false);
-  for (const auto &ta : llvm::enumerate(transpAttr)) {
-    int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
-    if (i < 0 || i >= rank)
-      return emitOpError("transposition index out of range: ") << i;
-    if (seen[i])
-      return emitOpError("duplicate position index: ") << i;
-    seen[i] = true;
-    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
-      return emitOpError("dimension size mismatch at: ") << i;
+  for (const auto &ta : llvm::enumerate(perm)) {
+    if (ta.value() < 0 || ta.value() >= rank)
+      return emitOpError("transposition index out of range: ") << ta.value();
+    if (seen[ta.value()])
+      return emitOpError("duplicate position index: ") << ta.value();
+    seen[ta.value()] = true;
+    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
+      return emitOpError("dimension size mismatch at: ") << ta.value();
   }
   return success();
 }
@@ -5452,13 +5449,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
-    auto getPermutation = [](vector::TransposeOp transpose) {
-      SmallVector<int64_t, 4> permutation;
-      transpose.getTransp(permutation);
-      return permutation;
-    };
-
     // Composes two permutations: result[i] = permutation1[permutation2[i]].
     auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                   ArrayRef<int64_t> permutation2) {
@@ -5475,12 +5465,11 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
       return failure();
 
     SmallVector<int64_t, 4> permutation = composePermutations(
-        getPermutation(parentTransposeOp), getPermutation(transposeOp));
+        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
     // Replace 'transposeOp' with a new transpose operation.
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
         transposeOp, transposeOp.getResult().getType(),
-        parentTransposeOp.getVector(),
-        vector::getVectorSubscriptAttr(rewriter, permutation));
+        parentTransposeOp.getVector(), permutation);
     return success();
   }
 };
@@ -5539,8 +5528,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
 
     // Get the transpose permutation and apply it to the vector.create_mask or
     // vector.constant_mask operands.
-    SmallVector<int64_t> permutation;
-    transpOp.getTransp(permutation);
+    ArrayRef<int64_t> permutation = transpOp.getPermutation();
 
     if (createMaskOp) {
       auto maskOperands = createMaskOp.getOperands();
@@ -5572,10 +5560,6 @@ void vector::TransposeOp::getCanonicalizationPatterns(
               TransposeFolder, FoldTransposeSplat>(context);
 }
 
-void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(getTransp(), results);
-}
-
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index dee786007c80630..97f6caca1b25ccc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -327,9 +327,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     VectorType resType = op.getResultVectorType();
 
     // Set up convenience transposition table.
-    SmallVector<int64_t> transp;
-    for (auto attr : op.getTransp())
-      transp.push_back(cast<IntegerAttr>(attr).getInt());
+    ArrayRef<int64_t> transp = op.getPermutation();
 
     if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
         succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 713f9cb72c82cec..a20c8aeeb6f7108 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -212,8 +212,7 @@ struct CombineContractABTranspose final
       if (!transposeOp)
         continue;
       AffineMap permutationMap = AffineMap::getPermutationMap(
-          extractVector<unsigned>(transposeOp.getTransp()),
-          contractOp.getContext());
+          transposeOp.getPermutation(), contractOp.getContext());
       map = inversePermutation(permutationMap).compose(map);
       *operand = transposeOp.getVector();
       changed = true;
@@ -279,13 +278,13 @@ struct CombineContractResultTranspose final
 
     // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
     // To index into A in contract, we need revert(f)(g(C)) -> A.
-    auto accTMap = AffineMap::getPermutationMap(
-        extractVector<unsigned>(accTOp.getTransp()), context);
+    auto accTMap =
+        AffineMap::getPermutationMap(accTOp.getPermutation(), context);
 
     // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
     // To index into E in contract, we need h(g(C)) -> E.
-    auto resTMap = AffineMap::getPermutationMap(
-        extractVector<unsigned>(resTOp.getTransp()), context);
+    auto resTMap =
+        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
     auto combinedResMap = resTMap.compose(contractMap);
 
     // The accumulator and result share the same indexing map. So they should be
@@ -490,7 +489,7 @@ struct ReorderElementwiseOpsOnTranspose final
 
     // Make sure all operands are transpose/constant ops and collect their
     // transposition maps.
-    SmallVector<ArrayAttr> transposeMaps;
+    SmallVector<ArrayRef<int64_t>> transposeMaps;
     transposeMaps.reserve(op->getNumOperands());
     // Record the initial type before transposition. We'll use its shape later.
     // Any type will do here as we will check all transpose maps are the same.
@@ -498,7 +497,7 @@ struct ReorderElementwiseOpsOnTranspose final
     for (Value operand : op->getOperands()) {
       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
       if (transposeOp) {
-        transposeMaps.push_back(transposeOp.getTransp());
+        transposeMaps.push_back(transposeOp.getPermutation());
         srcType = transposeOp.getSourceVectorType();
       } else if (!matchPattern(operand, m_Constant())) {
         return failure();
@@ -517,7 +516,7 @@ struct ReorderElementwiseOpsOnTranspose final
 
     // If there are constant operands, we need to insert inverse transposes for
     // them. Calculate the inverse order first.
-    auto order = extractVector<unsigned>(transposeMaps.front());
+    auto order = transposeMaps.front();
     SmallVector<int64_t> invOrder(order.size());
     for (int i = 0, e = order.size(); i < e; ++i)
       invOrder[order[i]] = i;
@@ -532,8 +531,7 @@ struct ReorderElementwiseOpsOnTranspose final
             srcType.getShape(),
             cast<VectorType>(operand.getType()).getElementType());
         srcValues.push_back(rewriter.create<vector::TransposeOp>(
-            operand.getLoc(), vectorType, operand,
-            rewriter.getI64ArrayAttr(invOrder)));
+            operand.getLoc(), vectorType, operand, invOrder));
       }
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 4cfac7de29ee76f..78b041255443c30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -537,8 +537,7 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
-    SmallVector<int64_t> permutation;
-    transposeOp.getTransp(permutation);
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
 
     // Unroll the computation.
     for (SmallVector<int64_t> elementOffsets :
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 467a521f9eada96..48cd67ad86c63fb 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -87,14 +87,11 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
   if (srcGtOneDims.size() != 2)
     return failure();
 
-  SmallVector<int64_t> transp;
-  for (auto attr : op.getTransp())
-    transp.push_back(cast<IntegerAttr>(attr).getInt());
-
   // Check whether the two source vector dimensions that are greater than one
   // must be transposed with each other so that we can apply one of the 2-D
   // transpose pattens. Otherwise, these patterns are not applicable.
-  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
+                                  op.getPermutation()))
     return failure();
 
   return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 93a8d048e0a61d5..80a26a595edee0a 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -236,6 +236,12 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
   assert(permutationMap.isPermutation() && "Invalid permutation vector");
   return permutationMap;
 }
+AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
+                                       MLIRContext *context) {
+  SmallVector<unsigned> perm = llvm::map_to_vector(
+      permutation, [](int64_t i) { return static_cast<unsigned>(i); });
+  return AffineMap::getPermutationMap(perm, context);
+}
 
 AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
                                                ArrayRef<unsigned> targets,

>From 5b48e0b0f5eba3f2b0029617ad7313aa10a0c266 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 16:29:19 +0900
Subject: [PATCH 2/2] [mlir][vector] Add extract(transpose(broadcast(x)))
 canonicalization

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  9 ++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 73 ++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir    | 14 ++++
 3 files changed, 93 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1397d4caf1d9d61..49860cadcd12c26 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -386,10 +386,15 @@ def Vector_BroadcastOp :
       return ::llvm::cast<VectorType>(getVector().getType());
     }
 
-    /// Return the dimensions of the result vector that were formerly ones in the
-    /// source tensor and thus correspond to "dim-1" broadcasting.
+    /// Return the dimensions of the result vector that were formerly ones in
+    /// the source vector and thus correspond to "dim-1" broadcasting.
     llvm::SetVector<int64_t> computeBroadcastedUnitDims();
 
+    /// Return the leading dimensions of the result vector that were newly
+    /// added via rank extension, i.e., that have no counterpart in the source
+    /// vector. Dimensions that were "dim-1" broadcasted are not included.
+    llvm::SetVector<int64_t> computeRankExtendedDims();
+
     /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
     /// `broadcastedDims` dimensions in the dstShape are broadcasted.
     /// This requires (and asserts) that the broadcast is free of dim-1
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..957143d6c13e9e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1897,6 +1897,62 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Canonicalize extract(transpose(broadcast(x))) constructs, where the
+/// broadcast adds a new dimension and the extraction removes it again.
+class ExtractOpTransposedBroadcastDim final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Skip vector.extract ops that do not remove any dimensions.
+    if (extractOp.getNumIndices() == 0)
+      return failure();
+    // Look for extract(transpose(broadcast(x))) pattern.
+    auto transposeOp =
+        extractOp.getVector().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp || transposeOp.getPermutation().empty())
+      return failure();
+    auto broadcastOp =
+        transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp)
+      return failure();
+    // Check if the first dimension that is being removed by the vector.extract
+    // was added by the vector.broadcast.
+    int64_t removedDim = transposeOp.getPermutation()[0];
+    llvm::SetVector<int64_t> rankExtendedDims =
+        broadcastOp.computeRankExtendedDims();
+    if (!rankExtendedDims.contains(removedDim))
+      return failure();
+
+    // 1. Create new vector.broadcast without the removed dimension.
+    SmallVector<int64_t> newBroadcastShape(
+        broadcastOp.getResultVectorType().getShape());
+    newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+    auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+        broadcastOp.getLoc(),
+        VectorType::get(newBroadcastShape,
+                        broadcastOp.getResultVectorType().getElementType()),
+        broadcastOp.getSource());
+
+    // 2. Create new vector.transpose.
+    SmallVector<int64_t> newPermutation;
+    for (int64_t dim : transposeOp.getPermutation().drop_front())
+      newPermutation.push_back(dim < transposeOp.getPermutation()[0] ? dim
+                                                                     : dim - 1);
+    auto newTranspose = rewriter.create<vector::TransposeOp>(
+        transposeOp.getLoc(), newBroadcast, newPermutation);
+
+    // 3. Create new vector.extract without the outermost dimension.
+    SmallVector<OpFoldResult> mixedPositions = extractOp.getMixedPosition();
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, newTranspose, ArrayRef(mixedPositions).drop_front());
+
+    return success();
+  }
+};
+
 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2062,7 +2118,8 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+              ExtractOpFromBroadcast, ExtractOpTransposedBroadcastDim,
+              ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
 }
 
@@ -2112,6 +2169,20 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
                                       getResultVectorType().getShape());
 }
 
+llvm::SetVector<int64_t> BroadcastOp::computeRankExtendedDims() {
+  llvm::SetVector<int64_t> broadcastedUnitDims = computeBroadcastedUnitDims();
+  llvm::SetVector<int64_t> result;
+  auto vecSrcType = dyn_cast<VectorType>(getSourceType());
+  int64_t rankDiff =
+      vecSrcType ? getResultVectorType().getRank() - vecSrcType.getRank()
+                 : getResultVectorType().getRank();
+  for (int64_t i = 0; i < rankDiff; ++i) {
+    if (!broadcastedUnitDims.contains(i))
+      result.insert(i);
+  }
+  return result;
+}
+
 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
 /// This requires (and asserts) that the broadcast is free of dim-1
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..a6b4f7f2717da81 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2524,3 +2524,17 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_of_transposed_broadcast_dim(
+//  CHECK-SAME:     %[[arg0:.*]]: vector<4x1xf32>
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[arg0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
+//       CHECK:   %[[tp:.*]] = vector.transpose %[[bc]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+//       CHECK:   return %[[tp]]
+func.func @extract_of_transposed_broadcast_dim(%arg0: vector<4x1xf32>) -> vector<1x100x4x5xf32> {
+  %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x1x4x1xf32>
+  %1 = vector.transpose %0, [2, 4, 0, 3, 1] : vector<100x5x1x4x1xf32> to vector<1x1x100x4x5xf32>
+  %2 = vector.extract %1[0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>
+  return %2 : vector<1x100x4x5xf32>
+}



More information about the Mlir-commits mailing list