[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dims from vector.transpose (PR #102017)

Jakub Kuderski llvmlistbot at llvm.org
Wed Aug 7 07:30:34 PDT 2024


================
@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// A pattern to drop unit dims from vector.transpose.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %transpose = vector.transpose %vector, [3, 0, 1, 2]
+///    : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %dropDims = vector.shape_cast %vector
+///    : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %transpose = vector.transpose %dropDims, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  %restoreDims = vector.shape_cast %transpose
+///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///  ```
+struct DropUnitDimsFromTransposeOp final
+    : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType sourceTypeWithoutUnitDims =
+        dropNonScalableUnitDimFromType(sourceType);
+
+    if (sourceType == sourceTypeWithoutUnitDims)
+      return failure();
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
+    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from transpose permutation.
+    ArrayRef<int64_t> perm = op.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (sourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    auto loc = op.getLoc();
----------------
kuhar wrote:

ubernit
```suggestion
    Location loc = op.getLoc();
```

https://github.com/llvm/llvm-project/pull/102017


More information about the Mlir-commits mailing list