[Mlir-commits] [mlir] [mlir][vector] Add `SwapShapeCastOfTranspose` canonicalizer pattern (PR #100933)

Jakub Kuderski llvmlistbot at llvm.org
Mon Jul 29 07:24:47 PDT 2024


================
@@ -5480,12 +5480,100 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Returns a range over the dims (including scalability) of a VectorType.
+///
+/// Each element pairs one entry of `vType.getShape()` with the matching entry
+/// of `vType.getScalableDims()`, i.e. a (size, scalableFlag) tuple.
+/// `llvm::zip_equal` assumes both sequences have the same length (checked in
+/// assertion-enabled builds).
+static auto getDims(VectorType vType) {
+  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Returns `vType` with every fixed-size unit dim (size 1 and not scalable)
+/// removed. Scalable unit dims (`[1]`) are kept, and the element type is
+/// unchanged. If every dim is a fixed unit dim, the result has rank 0.
+static VectorType dropUnitDims(VectorType vType) {
+  SmallVector<int64_t> keptSizes;
+  SmallVector<bool> keptScalableFlags;
+  for (auto [size, isScalable] : getDims(vType)) {
+    // Only dims that are both unit-sized and fixed are dropped.
+    const bool isFixedUnitDim = (size == 1) && !isScalable;
+    if (isFixedUnitDim)
+      continue;
+    keptSizes.push_back(size);
+    keptScalableFlags.push_back(isScalable);
+  }
+  return VectorType::get(keptSizes, vType.getElementType(), keptScalableFlags);
+}
+
+/// A pattern to swap shape_cast(transpose) with transpose(shape_cast) if the
+/// shape_cast only drops unit dimensions.
+///
+/// This simplifies the transpose making it more likely to be matched by further
+/// patterns.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.transpose %vector, [3, 0, 1, 2]
+///         : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp =
+        shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (resultType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+    if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+    auto transposeSourceVectorType = transposeOp.getSourceVectorType();
----------------
kuhar wrote:

nit:
```suggestion
    VectorType transposeSourceVectorType = transposeOp.getSourceVectorType();
```

https://github.com/llvm/llvm-project/pull/100933


More information about the Mlir-commits mailing list