[Mlir-commits] [mlir] [mlir][vector]Enable DropUnitDimFromTransposeOp (PR #93007)

Han-Chung Wang llvmlistbot at llvm.org
Thu May 30 15:03:14 PDT 2024


================
@@ -1695,6 +1695,77 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// Removes unit dimensions from a transpose op. Generates a vector.shape_cast
+/// on the operand and result to match types.
+///
+/// Ex:
+/// ```
+///   %tr = vector.transpose %arg0, [3, 1, 2, 0]: vector<1x4x1x2xf32> to
+///   vector<2x4x1x1xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %sc0 = vector.shape_cast %arg0 : vector<1x4x1x2xf32> to vector<4x2xf32>
+/// %tr = vector.transpose %sc0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+/// %sc1 = vector.shape_cast %tr : vector<2x4xf32> to vector<2x4x1x1xf32>
+/// ```
+struct DropUnitDimFromTransposeOp final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceVectorType = transposeOp.getSourceVectorType();
+    if (sourceVectorType.getRank() < 2)
+      return failure();
+
+    VectorType newVType = sourceVectorType;
+    SmallVector<int64_t> newPerm =
+        llvm::to_vector(transposeOp.getPermutation());
+    unsigned removedDims = 0;
+    auto shape = sourceVectorType.getShape();
+    for (const auto &dim : llvm::enumerate(shape)) {
+      if (dim.value() == 1 &&
+          !sourceVectorType.getScalableDims()[dim.index()]) {
+        newVType =
+            VectorType::Builder(newVType).dropDim(dim.index() - removedDims);
+        for (unsigned permutationIdx = 0; permutationIdx < newPerm.size();
+             ++permutationIdx) {
+          // Erase from permutation map the dropped unary dimension.
+          if ((unsigned)newPerm[permutationIdx] == dim.index() - removedDims) {
+            newPerm.erase(newPerm.begin() + permutationIdx);
+            permutationIdx--;
+          }
+          // Decrement all dimensions of higher rank to keep permutation map
+          // in range of the new rank.
+          else if ((unsigned)newPerm[permutationIdx] >
+                   dim.index() - removedDims) {
+            newPerm[permutationIdx]--;
+          }
+        }
+        removedDims++;
+      }
+    }
+    if (!removedDims)
+      return failure();
+
+    auto loc = transposeOp->getLoc();
+    auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType,
+                                                     transposeOp.getVector());
+    // Create an updated Transpose Op without unit dim.
+    vector::TransposeOp newTransposeOp =
+        rewriter.create<vector::TransposeOp>(loc, opSC, newPerm);
+
+    // Restore the unit dim by applying vector.shape_cast to the result.
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        transposeOp, transposeOp.getResultVectorType(), newTransposeOp);
+
+    return failure();
----------------
hanhanW wrote:

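Should this be `return success()`? The transpose op has already been replaced via `replaceOpWithNewOp` at this point, so the pattern must report success; returning `failure()` after modifying the IR violates the rewrite-pattern contract (a pattern that fails must leave the IR unchanged).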

https://github.com/llvm/llvm-project/pull/93007


More information about the Mlir-commits mailing list