[Mlir-commits] [mlir] [mlir][VectorOps] Add fold vector.shuffle -> vector.interleave (4/4) (PR #80968)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 21 02:49:25 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This folds fixed-size vector.shuffle ops that perform a 1-D interleave
to a vector.interleave operation.
For example:
```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```
folds to:
```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```
Depends on: #80967
---
Full diff: https://github.com/llvm/llvm-project/pull/80968.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+42-1)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+23)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5be6a628904cdf..59ca634ce3e3d6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2479,11 +2479,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
}
};
+/// Pattern to rewrite a fixed-size, 1-D interleave expressed via
+/// vector.shuffle into vector.interleave.
+///
+/// A shuffle is an interleave when both operands have the same type, the
+/// result has exactly twice as many elements as one operand, and the mask
+/// alternates between lanes of the two operands. For N elements per operand
+/// the mask is:
+///   [0, N, 1, N + 1, ..., N - 1, 2N - 1]
+///
+/// Example:
+///   %0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
+/// becomes:
+///   %0 = vector.interleave %a, %b : vector<2xi32>
+class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent a scalable interleave");
+
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent an n-D interleave");
+
+    // The result is known to be 1-D here, so comparing element counts is
+    // equivalent to comparing shapes. It also avoids building an ArrayRef
+    // over a temporary initializer_list.
+    VectorType sourceType = op.getV1VectorType();
+    if (sourceType != op.getV2VectorType() ||
+        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp types don't match an interleave");
+    }
+
+    // Lanes of V2 start at index `numSourceElements` in the shuffle mask.
+    // int64_t is used throughout so the element count is never narrowed.
+    ArrayAttr shuffleMask = op.getMask();
+    const int64_t numSourceElements = sourceType.getNumElements();
+    for (int64_t i = 0; i < numSourceElements; ++i) {
+      int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
+      int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+      if (maskValueA != i || maskValueB != numSourceElements + i)
+        return rewriter.notifyMatchFailure(op,
+                                           "ShuffleOp mask not interleaving");
+    }
+
+    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
+    return success();
+  }
+};
+
} // namespace
+/// Registers the vector.shuffle canonicalization patterns: ShuffleSplat,
+/// ShuffleInterleave and Canonicalize0DShuffleOp.
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
 MLIRContext *context) {
- results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
+ results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e5197..4c73a6271786e6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}
+
+// -----
+
+// A shuffle of two 0-D vectors with mask [0, 1] takes %arg0's element and
+// then %arg1's element, so it is expected to fold to vector.interleave on the
+// 0-D operands.
+// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
+{
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
+ return %0 : vector<2xf64>
+}
+
+// -----
+
+// Shuffle of two vector<6xi32>: the mask alternates lanes of %arg0 (0..5)
+// with lanes of %arg1 (6..11), so it is expected to fold to vector.interleave.
+// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+ return %0 : vector<12xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/80968
More information about the Mlir-commits
mailing list