[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dims from vector.transpose (PR #102017)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Aug 7 07:30:34 PDT 2024
================
@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// A pattern to drop unit dims from vector.transpose.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
+/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %dropDims = vector.shape_cast %vector
+/// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+/// %transpose = vector.transpose %dropDims, [1, 0]
+/// : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// %restoreDims = vector.shape_cast %transpose
+/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+/// ```
+struct DropUnitDimsFromTransposeOp final
+ : OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType sourceTypeWithoutUnitDims =
+ dropNonScalableUnitDimFromType(sourceType);
+
+ if (sourceType == sourceTypeWithoutUnitDims)
+ return failure();
+
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+ auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
+ SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
+ int64_t droppedDims = 0;
+ for (auto [i, dim] : llvm::enumerate(sourceDims)) {
+ droppedDimsBefore[i] = droppedDims;
+ if (dim == std::make_tuple(1, false))
+ ++droppedDims;
+ }
+
+ // Drop unit dims from transpose permutation.
+ ArrayRef<int64_t> perm = op.getPermutation();
+ SmallVector<int64_t> newPerm;
+ for (int64_t idx : perm) {
+ if (sourceDims[idx] == std::make_tuple(1, false))
+ continue;
+ newPerm.push_back(idx - droppedDimsBefore[idx]);
+ }
+
+ auto loc = op.getLoc();
----------------
kuhar wrote:
ubernit
```suggestion
Location loc = op.getLoc();
```
https://github.com/llvm/llvm-project/pull/102017
More information about the Mlir-commits
mailing list