[Mlir-commits] [mlir] [mlir][vector] Add `SwapShapeCastOfTranspose` canonicalizer pattern (PR #100933)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jul 29 07:24:47 PDT 2024
================
@@ -5480,12 +5480,100 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+/// Returns a range of (size, isScalable) tuples, one per dimension of `vType`,
+/// pairing each entry of `vType.getShape()` with the matching entry of
+/// `vType.getScalableDims()` (zip_equal expects both lists to be the same
+/// length).
+static auto getDims(VectorType vType) {
+ // NOTE(review): keep both ArrayRefs as temporaries in this call. llvm::zip*
+ // stores rvalue arguments by value but lvalue arguments by reference, so
+ // hoisting them into named locals would leave the returned range dangling.
+ return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Returns a VectorType with the same element type as `vType` but with every
+/// fixed-size (non-scalable) dimension of extent 1 removed. Scalable unit
+/// dimensions (`[1]`) and all non-unit dimensions are kept, in their original
+/// order.
+static VectorType dropUnitDims(VectorType vType) {
+  SmallVector<int64_t> keptSizes;
+  SmallVector<bool> keptScalability;
+  for (auto [size, isScalable] : getDims(vType)) {
+    // Only a fixed `1` is dropped; a scalable `[1]` is preserved.
+    bool isFixedUnitDim = size == 1 && !isScalable;
+    if (isFixedUnitDim)
+      continue;
+    keptSizes.push_back(size);
+    keptScalability.push_back(isScalable);
+  }
+  return VectorType::get(keptSizes, vType.getElementType(), keptScalability);
+}
+
+/// A pattern to swap shape_cast(transpose) with transpose(shape_cast) if the
+/// shape_cast only drops unit dimensions.
+///
+/// This simplifies the transpose making it more likely to be matched by further
+/// patterns.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.transpose %vector, [3, 0, 1, 2]
+/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeOp =
+ shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp)
+ return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (resultType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+ if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+ auto transposeSourceVectorType = transposeOp.getSourceVectorType();
+ auto transposeSourceDims =
+ llvm::to_vector(getDims(transposeSourceVectorType));
+
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+ SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
+ int64_t droppedDims = 0;
+ for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
+ droppedDimsBefore[i] = droppedDims;
+ if (dim == std::make_tuple(1, false))
+ ++droppedDims;
+ }
+
+ // Drop unit dims from transpose permutation.
+ auto perm = transposeOp.getPermutation();
----------------
kuhar wrote:
It's not obvious to me what the type of `perm` is here (`ArrayRef<int64_t>`?). Could we use the actual type instead of `auto`?
https://github.com/llvm/llvm-project/pull/100933
More information about the Mlir-commits
mailing list