[Mlir-commits] [mlir] [mlir][VectorOps] Add fold vector.shuffle -> vector.interleave (4/4) (PR #80968)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Mar 6 05:05:18 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80968

>From d45d47e4a8b0ef29a6b08a729d4cd4622757b4ab Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 10:04:53 +0000
Subject: [PATCH 1/2] [mlir][VectorOps] Add fold vector.shuffle ->
 vector.interleave

This folds fixed-size vector.shuffle ops that perform a 1-D interleave
to a vector.interleave operation.

i.e.:

```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```

to:

```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 43 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 23 ++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5be6a628904cdf..59ca634ce3e3d6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2479,11 +2479,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
   }
 };
 
+/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
+/// vector.interleave.
+class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent a scalable interleave");
+
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent an n-D interleave");
+
+    VectorType sourceType = op.getV1VectorType();
+    if (sourceType != op.getV2VectorType() ||
+        ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
+            resultType.getShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp types don't match an interleave");
+    }
+
+    ArrayAttr shuffleMask = op.getMask();
+    int64_t resultVectorSize = resultType.getNumElements();
+    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
+      int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
+      int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
+        return rewriter.notifyMatchFailure(op,
+                                           "ShuffleOp mask not interleaving");
+    }
+
+    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
+    return success();
+  }
+};
+
 } // namespace
 
 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
+  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e5197..4c73a6271786e6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
+{
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
+  return %0 : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+  return %0 : vector<12xi32>
+}

>From 657461153ca4cb898c1d1110e1a6e979fa2279b6 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 6 Mar 2024 13:04:02 +0000
Subject: [PATCH 2/2] Fixups

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 59ca634ce3e3d6..75f6220ad8f3fa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2498,8 +2498,7 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
 
     VectorType sourceType = op.getV1VectorType();
     if (sourceType != op.getV2VectorType() ||
-        ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
-            resultType.getShape()) {
+        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
       return rewriter.notifyMatchFailure(
           op, "ShuffleOp types don't match an interleave");
     }



More information about the Mlir-commits mailing list