[Mlir-commits] [mlir] [mlir][Vector] Replace vector.transpose with vector.shape_cast (PR #94912)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jun 9 13:29:43 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Prashant Kumar (pashu123)
<details>
<summary>Changes</summary>
Define the permutation width as one plus the last index `i` in the permutation array for which `permutation[i] != i`. This pattern applies to transpose operations whose input vector has at most one non-unit dimension among its first permutation-width dimensions, and replaces such a transpose with a shape cast operation. For example:
%0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
is replaced by
%0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
since the permutation width is 2 and only one of the first two input dimensions (4 and 1) is non-unit.
---
Full diff: https://github.com/llvm/llvm-project/pull/94912.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+56-2)
- (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+7)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index ca8a6f6d82a6e..a1bfb4063f756 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -451,6 +451,59 @@ class Transpose2DWithUnitDimToShapeCast
}
};
+// Define the permutation width as one plus the last index `i` for which
+// `permutation[i] != i`; every dimension at or beyond the permutation width is
+// left in place by the transpose. This pattern applies to transposes whose
+// input vector has at most one non-unit dimension among its first
+// permutation-width dimensions. Such a transpose only moves unit dimensions
+// around (at most) one non-unit dimension, so the row-major order of the
+// elements is unchanged and the transpose can be replaced by a shape cast.
+// Identity permutations (including the empty permutation of a 0-D vector) are
+// rejected; they are left to the folder.
+// For example:
+//   %0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
+// is replaced by
+//   %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
+// since the permutation width is 2 and only one of the dims {4, 1} is non-unit.
+class TransposeWithUnitDimToShapeCast
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  TransposeWithUnitDimToShapeCast(MLIRContext *context,
+                                  PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType inputType = op.getSourceVectorType();
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+
+    // Permutation width: one past the last position the permutation moves.
+    // Zero means the permutation is the identity. Starting from zero (rather
+    // than one) also keeps the shape lookup below in bounds for 0-D vectors.
+    ArrayRef<int64_t> perm = op.getPermutation();
+    int64_t permWidth = 0;
+    for (auto &&[idx, val] : llvm::enumerate(perm))
+      if (static_cast<int64_t>(idx) != val)
+        permWidth = static_cast<int64_t>(idx) + 1;
+    if (permWidth == 0)
+      return rewriter.notifyMatchFailure(op, "identity permutation");
+
+    // The transpose is a pure reshape only if at most one of the permuted
+    // dimensions is non-unit.
+    ArrayRef<int64_t> permutedDims =
+        inputType.getShape().take_front(permWidth);
+    if (llvm::count_if(permutedDims, [](int64_t dim) { return dim != 1; }) > 1)
+      return rewriter.notifyMatchFailure(
+          op, "more than one permuted dimension is non-unit");
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), op.getVector());
+    return success();
+  }
+};
+
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -523,8 +576,9 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
- patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
- benefit);
+ // TransposeWithUnitDimToShapeCast (added above) replaces a transpose with a
+ // vector.shape_cast when at most one of the permuted dimensions is non-unit.
+ // All patterns here are registered with the caller-supplied `benefit`.
+ // NOTE(review): Transpose2DWithUnitDimToShapeCast is defined outside this
+ // diff; judging by its name it may overlap with the new pattern on 2-D
+ // inputs -- confirm whether both registrations are still needed.
+ patterns
+ .add<Transpose2DWithUnitDimToShapeCast, TransposeWithUnitDimToShapeCast>(
+ patterns.getContext(), benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 219a72df52a19..d50d8d0d67da1 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -386,6 +386,13 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x
return %0 : vector<[1]x4xf32>
}
+// Dimensions 1 and 2 (sizes 2 and 1) are swapped; only one permuted dimension
+// is non-unit, so the transpose is a pure reshape.
+// CHECK-LABEL: func @transpose_nd
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x2x1x16xf32>
+//  CHECK-NEXT:   %[[CAST:.*]] = vector.shape_cast %[[ARG]] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+//  CHECK-NEXT:   return %[[CAST]] : vector<1x1x2x16xf32>
+func.func @transpose_nd(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
+  %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+  return %0 : vector<1x1x2x16xf32>
+}
+
+// The trailing dimension is not permuted and only one of the permuted leading
+// dimensions is non-unit.
+// CHECK-LABEL: func @transpose102_4x1x1xi32
+//  CHECK-SAME:   %[[ARG:.*]]: vector<4x1x1xi32>
+//  CHECK-NEXT:   %[[CAST:.*]] = vector.shape_cast %[[ARG]] : vector<4x1x1xi32> to vector<1x4x1xi32>
+//  CHECK-NEXT:   return %[[CAST]] : vector<1x4x1xi32>
+func.func @transpose102_4x1x1xi32(%arg0: vector<4x1x1xi32>) -> vector<1x4x1xi32> {
+  %0 = vector.transpose %arg0, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
+  return %0 : vector<1x4x1xi32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
``````````
</details>
https://github.com/llvm/llvm-project/pull/94912
More information about the Mlir-commits
mailing list