[Mlir-commits] [mlir] [mlir][VectorOps] Add fold vector.shuffle -> vector.interleave (4/4) (PR #80968)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Feb 7 02:23:09 PST 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/80968
This folds fixed-size vector.shuffle ops that perform a 1-D interleave
to a vector.interleave operation.
For example:
```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```
folds to:
```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```
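More generally, the masks recognized by this fold follow the pattern
`[0, n, 1, n+1, ..., n-1, 2n-1]` for two n-element inputs. For example
(following the pattern of the new canonicalization tests; SSA names are
illustrative):
```mlir
%0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi32>, vector<4xi32>
```
also folds to:
```mlir
%0 = vector.interleave %a, %b : vector<4xi32>
```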
Depends on: #80967
From 4b442eeb377cfcb60a0d05cf6f1ec6ba735c3152 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 09:53:22 +0000
Subject: [PATCH 1/4] [mlir][VectorOps] Add vector.interleave operation
The interleave operation constructs a new vector by interleaving the
elements from the trailing (or final) dimension of two input vectors,
returning a new vector where the trailing dimension is twice the size.
Note that for the n-D case this differs from the interleaving possible
with `vector.shuffle`, which would only operate on the leading
dimension.
Another key difference is this operation supports scalable vectors,
though currently a general LLVM lowering is limited to the case where
only the trailing dimension is scalable.
Example:
```mlir
%0 = vector.interleave %a, %b
: vector<[4]xi32> ; yields vector<[8]xi32>
%1 = vector.interleave %c, %d
: vector<8xi8> ; yields vector<16xi8>
%2 = vector.interleave %e, %f
: vector<f16> ; yields vector<2xf16>
%3 = vector.interleave %g, %h
: vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
%4 = vector.interleave %i, %j
: vector<6x3xf32> ; yields vector<6x6xf32>
```
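For contrast, a `vector.shuffle` on the same kind of 2-D inputs interleaves
along the leading dimension, moving whole rows rather than elements (a
hedged sketch; the shapes here are illustrative):
```mlir
%s = vector.shuffle %a, %b [0, 2, 1, 3]
    : vector<2x4xf32>, vector<2x4xf32> ; yields vector<4x4xf32> (rows a0, b0, a1, b1)
```
whereas `vector.interleave` on the same inputs would yield `vector<2x8xf32>`.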
Note: This change alone does not add any lowerings.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 63 +++++++++++++++++++
mlir/test/Dialect/Vector/ops.mlir | 35 +++++++++++
2 files changed, 98 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bc08f8d07fb0d..6d50b0654bc57 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -478,6 +478,69 @@ def Vector_ShuffleOp :
let hasCanonicalizer = 1;
}
+def Vector_InterleaveOp :
+ Vector_Op<"interleave", [Pure,
+ AllTypesMatch<["lhs", "rhs"]>,
+ TypesMatchWith<
+ "type of 'result' is double the width of the inputs",
+ "lhs", "result",
+ [{
+ [&]() -> ::mlir::VectorType {
+ auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
+ ::mlir::VectorType::Builder builder(vectorType);
+ if (vectorType.getRank() == 0) {
+ static constexpr int64_t v2xty_shape[] = { 2 };
+ return builder.setShape(v2xty_shape);
+ }
+ auto lastDim = vectorType.getRank() - 1;
+ return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
+ }()
+ }]>]> {
+ let summary = "constructs a vector by interleaving two input vectors";
+ let description = [{
+ The interleave operation constructs a new vector by interleaving the
+ elements from the trailing (or final) dimension of two input vectors,
+ returning a new vector where the trailing dimension is twice the size.
+
+ Note that for the n-D case this differs from the interleaving possible with
+ `vector.shuffle`, which would only operate on the leading dimension.
+
+ Another key difference is this operation supports scalable vectors, though
+ currently a general LLVM lowering is limited to the case where only the
+ trailing dimension is scalable.
+
+ Example:
+ ```mlir
+ %0 = vector.interleave %a, %b
+ : vector<[4]xi32> ; yields vector<[8]xi32>
+ %1 = vector.interleave %c, %d
+ : vector<8xi8> ; yields vector<16xi8>
+ %2 = vector.interleave %e, %f
+ : vector<f16> ; yields vector<2xf16>
+ %3 = vector.interleave %g, %h
+ : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
+ %4 = vector.interleave %i, %j
+ : vector<6x3xf32> ; yields vector<6x6xf32>
+ ```
+ }];
+
+ let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
+ let results = (outs AnyVector:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` type($lhs)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getSourceVectorType() {
+ return ::llvm::cast<VectorType>(getLhs().getType());
+ }
+ VectorType getResultVectorType() {
+ return ::llvm::cast<VectorType>(getResult().getType());
+ }
+ }];
+}
+
def Vector_ExtractElementOp :
Vector_Op<"extractelement", [Pure,
TypesMatchWith<"result type matches element type of vector operand",
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2f8530e7c171a..79a80be4f8b20 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1081,3 +1081,38 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
%min = vector.reduction <minnumf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
return %min: f32
}
+
+// CHECK-LABEL: @interleave_0d
+func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
+ // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
+ %0 = vector.interleave %a, %b : vector<f32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @interleave_1d
+func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
+ // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
+ %0 = vector.interleave %a, %b : vector<4xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @interleave_1d_scalable
+func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
+ // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
+ %0 = vector.interleave %a, %b : vector<[8]xi16>
+ return %0 : vector<[16]xi16>
+}
+
+// CHECK-LABEL: @interleave_2d
+func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
+ // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
+ %0 = vector.interleave %a, %b : vector<2x8xf32>
+ return %0 : vector<2x16xf32>
+}
+
+// CHECK-LABEL: @interleave_2d_scalable
+func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
+ // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
+ %0 = vector.interleave %a, %b : vector<2x[2]xf64>
+ return %0 : vector<2x[4]xf64>
+}
From c2047adb5672ec8eaa05e16f34dbbb794a0eba6c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 09:56:42 +0000
Subject: [PATCH 2/4] [mlir][VectorOps] Add conversion of 1-D vector.interleave
ops to LLVM
The 1-D case maps directly to LLVM: fixed-size interleaves lower to an
LLVM shufflevector, and scalable interleaves lower to the
vector.interleave2 intrinsic. The n-D case will be handled by unrolling
to 1-D first (in a later patch).
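Concretely (mirroring the conversion tests added below; SSA names are
illustrative), the two 1-D cases lower as:
```mlir
// Fixed-size: a single shufflevector with an interleaving mask.
%zip = llvm.shufflevector %lhs, %rhs [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>
// Scalable: the LLVM interleave2 intrinsic.
%zip_s = "llvm.intr.experimental.vector.interleave2"(%lhs_s, %rhs_s) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
```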
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 41 ++++++++++++++++++-
.../VectorToLLVM/vector-to-llvm.mlir | 37 +++++++++++++++++
2 files changed, 77 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b66b55ae8d57f..0d9a451d11ca8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1734,6 +1734,44 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
}
};
+/// Conversion pattern for a `vector.interleave`.
+/// This supports fixed-sized vectors and scalable vectors.
+struct VectorInterleaveOpLowering
+ : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = interleaveOp.getResultVectorType();
+ // n-D interleaves should have been lowered already.
+ if (resultType.getRank() != 1)
+ return failure();
+ // If the result is rank 1, then this directly maps to LLVM.
+ if (resultType.isScalable()) {
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
+ interleaveOp, typeConverter->convertType(resultType),
+ adaptor.getLhs(), adaptor.getRhs());
+ return success();
+ }
+    // Lower fixed-size interleaves to a shufflevector. While the
+    // vector.interleave2 intrinsic supports fixed and scalable vectors, the
+    // LangRef still recommends that fixed-size vectors use shufflevector, see:
+    // https://llvm.org/docs/LangRef.html#id876.
+ int64_t resultVectorSize = resultType.getNumElements();
+ SmallVector<int32_t> interleaveShuffleMask;
+ interleaveShuffleMask.reserve(resultVectorSize);
+ for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
+ interleaveShuffleMask.push_back(i);
+ interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
+ }
+ rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
+ interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
+ interleaveShuffleMask);
+ return success();
+ }
+};
+
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1758,7 +1796,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
- MaskedReductionOpConversion>(converter);
+ MaskedReductionOpConversion, VectorInterleaveOpLowering>(
+ converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1c13b16dfd9af..a46f2e101f3c3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2460,3 +2460,40 @@ func.func @make_fixed_vector_of_scalable_vector(%f : f64) -> vector<3x[2]xf64>
%res = vector.broadcast %f : f64 to vector<3x[2]xf64>
return %res : vector<3x[2]xf64>
}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_0d
+// CHECK-SAME: %[[LHS:.*]]: vector<i8>, %[[RHS:.*]]: vector<i8>)
+func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8> {
+ // CHECK: %[[LHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<i8> to vector<1xi8>
+ // CHECK: %[[RHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<i8> to vector<1xi8>
+ // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS_RANK1]], %[[RHS_RANK1]] [0, 1] : vector<1xi8>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.interleave %a, %b : vector<i8>
+ return %0 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_1d
+// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>, %[[RHS:.*]]: vector<8xf32>)
+func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32>
+{
+ // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS]], %[[RHS]] [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.interleave %a, %b : vector<8xf32>
+ return %0 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_1d_scalable
+// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, %[[RHS:.*]]: vector<[4]xi32>)
+func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32>
+{
+ // CHECK: %[[ZIP:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.interleave %a, %b : vector<[4]xi32>
+ return %0 : vector<[8]xi32>
+}
From fa7b99005f31ce050857ec0c8cc7c075f6b4140a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 10:00:45 +0000
Subject: [PATCH 3/4] [mlir][VectorOps] Add unrolling for n-D vector.interleave
ops
This unrolls n-D vector.interleave ops like:
```mlir
vector.interleave %i, %j : vector<6x3xf32>
```
to a sequence of 1-D operations, which can then be directly lowered to
LLVM.
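For a smaller example, `vector.interleave %i, %j : vector<2x3xf32>` would
unroll to roughly the following (a hand-written sketch; SSA names and the
exact constant form are illustrative):
```mlir
%cst = arith.constant dense<0.000000e+00> : vector<2x6xf32>
%lhs0 = vector.extract %i[0] : vector<3xf32> from vector<2x3xf32>
%rhs0 = vector.extract %j[0] : vector<3xf32> from vector<2x3xf32>
%zip0 = vector.interleave %lhs0, %rhs0 : vector<3xf32>
%res0 = vector.insert %zip0, %cst[0] : vector<6xf32> into vector<2x6xf32>
%lhs1 = vector.extract %i[1] : vector<3xf32> from vector<2x3xf32>
%rhs1 = vector.extract %j[1] : vector<3xf32> from vector<2x3xf32>
%zip1 = vector.interleave %lhs1, %rhs1 : vector<3xf32>
%res = vector.insert %zip1, %res0[1] : vector<6xf32> into vector<2x6xf32>
```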
---
.../Vector/Transforms/LoweringPatterns.h | 8 +++
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 +
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Transforms/LowerVectorInterleave.cpp | 64 +++++++++++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 48 ++++++++++++++
5 files changed, 122 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 57b39f5f52c6d..1cd3bab46396e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [InterleaveOpLowering]
+/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
+/// InterleaveOp until dim 1.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ff8e78a668e0f..e3a436c4a9400 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
+ populateVectorInterleaveLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns,
VectorTransformsOptions());
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef..f221b7462dfd7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
LowerVectorGather.cpp
+ LowerVectorInterleave.cpp
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
new file mode 100644
index 0000000000000..0ca38eba942a5
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -0,0 +1,64 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Progressive lowering of InterleaveOp.
+class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::InterleaveOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ // 1-D vector.interleave ops can be directly lowered to LLVM (later).
+ if (resultType.getRank() == 1)
+ return failure();
+
+ // Below we unroll the leading (or front) dimension. If that dimension is
+    // scalable, we can't unroll it.
+ if (resultType.getScalableDims().front())
+ return failure();
+
+ // n-D case: Unroll the leading dimension.
+ auto loc = op.getLoc();
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resultType, rewriter.getZeroAttr(resultType));
+ for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
+ Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
+ Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
+ Value interleave =
+ rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+ result = rewriter.create<InsertOp>(loc, interleave, result, idx);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a46f2e101f3c3..3cbca65472fb6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2497,3 +2497,51 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
%0 = vector.interleave %a, %b : vector<[4]xi32>
return %0 : vector<[8]xi32>
}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d
+// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+ // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+ // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi8>
+ // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x6xi8> to !llvm.array<2 x vector<6xi8>>
+ // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+ // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIP_DIM_0:.*]] = llvm.shufflevector %[[LHS_DIM_0]], %[[RHS_DIM_0]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIP_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<6xi8>>
+ // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+ // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIP_DIM_1:.*]] = llvm.shufflevector %[[LHS_DIM_1]], %[[RHS_DIM_1]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<6xi8>>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<6xi8>> to vector<2x6xi8>
+ // CHECK: return %[[RES]]
+ %0 = vector.interleave %a, %b : vector<2x3xi8>
+ return %0 : vector<2x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x[16]xi16>
+ // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[16]xi16> to !llvm.array<2 x vector<[16]xi16>>
+ // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+ // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[ZIP_DIM_0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_0]], %[[RHS_DIM_0]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIP_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<[8]xi16>>
+ // CHECK: %[[ZIP_DIM_1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_1]], %[[RHS_DIM_1]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+ // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<[16]xi16>>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<[16]xi16>> to vector<2x[16]xi16>
+ // CHECK: return %[[RES]]
+ %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+ return %0 : vector<2x[16]xi16>
+}
From 32c9285ae02407cd8a3a6c09f5ae2de158740365 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 10:04:53 +0000
Subject: [PATCH 4/4] [mlir][VectorOps] Add fold vector.shuffle ->
vector.interleave
This folds fixed-size vector.shuffle ops that perform a 1-D interleave
to a vector.interleave operation.
i.e.:
```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```
to:
```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```
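The fold also covers the rank-0 case (see the new canonicalization tests):
```mlir
%0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
```
becomes:
```mlir
%0 = vector.interleave %arg0, %arg1 : vector<f64>
```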
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 43 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 23 ++++++++++++
2 files changed, 65 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 452354413e883..084348e68270c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2478,11 +2478,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
}
};
+/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
+/// vector.interleave.
+class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ if (resultType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op, "ShuffleOp can't represent a scalable interleave");
+
+ if (resultType.getRank() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "ShuffleOp can't represent an n-D interleave");
+
+ VectorType sourceType = op.getV1VectorType();
+ if (sourceType != op.getV2VectorType() ||
+ ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
+ resultType.getShape()) {
+ return rewriter.notifyMatchFailure(
+ op, "ShuffleOp types don't match an interleave");
+ }
+
+ ArrayAttr shuffleMask = op.getMask();
+ int64_t resultVectorSize = resultType.getNumElements();
+ for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
+ int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
+ int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+ if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
+ return rewriter.notifyMatchFailure(op,
+ "ShuffleOp mask not interleaving");
+ }
+
+ rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
+ return success();
+ }
+};
+
} // namespace
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
+ results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e519..4c73a6271786e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
+{
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
+ return %0 : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+ return %0 : vector<12xi32>
+}