[Mlir-commits] [mlir] [mlir][VectorOps] Add fold vector.shuffle -> vector.interleave (4/4) (PR #80968)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 21 02:49:25 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This folds fixed-size vector.shuffle ops that perform a 1-D interleave
to a vector.interleave operation.

For example:

```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```

folds to:

```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```

Depends on: #<!-- -->80967

---
Full diff: https://github.com/llvm/llvm-project/pull/80968.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+42-1) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+23) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5be6a628904cdf..59ca634ce3e3d6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2479,11 +2479,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
   }
 };
 
+/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
+/// vector.interleave.
+class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent a scalable interleave");
+
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent an n-D interleave");
+
+    VectorType sourceType = op.getV1VectorType();
+    if (sourceType != op.getV2VectorType() ||
+        ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
+            resultType.getShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp types don't match an interleave");
+    }
+
+    ArrayAttr shuffleMask = op.getMask();
+    int64_t resultVectorSize = resultType.getNumElements();
+    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
+      int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
+      int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
+        return rewriter.notifyMatchFailure(op,
+                                           "ShuffleOp mask not interleaving");
+    }
+
+    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
+    return success();
+  }
+};
+
 } // namespace
 
 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
+  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e5197..4c73a6271786e6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
+{
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
+  return %0 : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+  return %0 : vector<12xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/80968


More information about the Mlir-commits mailing list