[Mlir-commits] [mlir] 8d163e5 - [mlir][Vector] Add 16x16 strategy to vector.transpose lowering.
Hanhan Wang
llvmlistbot at llvm.org
Sun Apr 23 11:05:50 PDT 2023
Author: Hanhan Wang
Date: 2023-04-23T11:05:41-07:00
New Revision: 8d163e5045073a5ac570225cc8e14cc9f6d72f09
URL: https://github.com/llvm/llvm-project/commit/8d163e5045073a5ac570225cc8e14cc9f6d72f09
DIFF: https://github.com/llvm/llvm-project/commit/8d163e5045073a5ac570225cc8e14cc9f6d72f09.diff
LOG: [mlir][Vector] Add 16x16 strategy to vector.transpose lowering.
It adds a `shuffle_16x16` strategy LowerVectorTranspose and renames `shuffle` to `shuffle_1d`. The idea is similar to 8x8 cases in x86Vector::avx2. The general algorithm is:
```
interleave 32-bit lanes using
8x _mm512_unpacklo_epi32
8x _mm512_unpackhi_epi32
interleave 64-bit lanes using
8x _mm512_unpacklo_epi64
8x _mm512_unpackhi_epi64
permute 128-bit lanes using
16x _mm512_shuffle_i32x4
permute 256-bit lanes using again
16x _mm512_shuffle_i32x4
```
After the first stage, they got transposed to
```
0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29
2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31
32 48 33 49 ...
34 50 35 51 ...
64 80 65 81 ...
...
```
After the second stage, they got transposed to
```
0 16 32 48 ...
1 17 33 49 ...
2 18 34 50 ...
3 19 35 51 ...
64 80 96 112 ...
65 81 97 113 ...
66 82 98 114 ...
67 83 99 115 ...
...
```
After the third stage, they got transposed to
```
0 16 32 48 8 24 40 56 64 80 96 112 ...
1 17 33 49 ...
2 18 34 50 ...
3 19 35 51 ...
4 20 36 52 ...
5 21 37 53 ...
6 22 38 54 ...
7 23 39 55 ...
128 144 160 176 ...
129 145 161 177 ...
...
```
After the last stage, they got transposed to
```
0 16 32 48 64 80 96 112 ... 240
1 17 33 49 65 81 97 113 ... 241
2 18 34 50 66 82 98 114 ... 242
...
15 31 47 63 79 95 111 127 ... 255
```
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D148685
Added:
mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
mlir/test/Dialect/LLVM/transform-e2e.mlir
mlir/test/Dialect/Vector/transform-vector.mlir
mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
index fb1f2ab717687..ef0951ab1d166 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
@@ -18,14 +18,17 @@ def VectorTransposeLowering_Elementwise:
// intrinsics.
def VectorTransposeLowering_FlatTranspose:
I32EnumAttrCase<"Flat", 1, "flat_transpose">;
-// Lower 2-D transpose to `vector.shuffle`.
-def VectorTransposeLowering_Shuffle:
- I32EnumAttrCase<"Shuffle", 2, "shuffle">;
+// Lower 2-D transpose to `vector.shuffle` on 1-D vector.
+def VectorTransposeLowering_Shuffle1D:
+ I32EnumAttrCase<"Shuffle1D", 2, "shuffle_1d">;
+// Lower 2-D transpose to `vector.shuffle` on 16x16 vector.
+def VectorTransposeLowering_Shuffle16x16:
+ I32EnumAttrCase<"Shuffle16x16", 3, "shuffle_16x16">;
def VectorTransposeLoweringAttr : I32EnumAttr<
"VectorTransposeLowering",
"control the lowering of `vector.transpose` operations.",
[VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
- VectorTransposeLowering_Shuffle]> {
+ VectorTransposeLowering_Shuffle1D, VectorTransposeLowering_Shuffle16x16]> {
let cppNamespace = "::mlir::vector";
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 2888e347ce441..1408c03f21456 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -52,6 +52,253 @@ static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}
+/// Returns true if the lowering option is a vector shuffle based approach.
+static bool isShuffleLike(VectorTransposeLowering lowering) {
+ return lowering == VectorTransposeLowering::Shuffle1D ||
+ lowering == VectorTransposeLowering::Shuffle16x16;
+}
+
+/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
+/// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
+/// create the mask for `numBits` bits vector. The `numBits` have to be a
+/// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
+/// 512, there should be 16 elements in the final result. It constructs the
+/// below mask to get the unpack elements.
+/// [0, 1, 16, 17,
+/// 0+4, 1+4, 16+4, 17+4,
+/// 0+8, 1+8, 16+8, 17+8,
+/// 0+12, 1+12, 16+12, 17+12]
+static SmallVector<int64_t>
+getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
+ assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
+ int numElem = numBits / 32;
+ SmallVector<int64_t> res;
+ for (int i = 0; i < numElem; i += 4)
+ for (int64_t v : vals)
+ res.push_back(v + i);
+ return res;
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+/// vector.shuffle on v1, v2, [0, 1, 16, 17,
+/// 0+4, 1+4, 16+4, 17+4,
+/// 0+8, 1+8, 16+8, 17+8,
+/// 0+12, 1+12, 16+12, 17+12].
+static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+ int numBits) {
+ int numElem = numBits / 32;
+ return b.create<vector::ShuffleOp>(
+ v1, v2,
+ getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+/// vector.shuffle, v1, v2, [2, 3, 18, 19,
+/// 2+4, 3+4, 18+4, 19+4,
+/// 2+8, 3+8, 18+8, 19+8,
+/// 2+12, 3+12, 18+12, 19+12].
+static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+ int numBits) {
+ int numElem = numBits / 32;
+ return b.create<vector::ShuffleOp>(
+ v1, v2,
+ getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
+ numBits));
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+/// vector.shuffle, v1, v2, [0, 16, 1, 17,
+/// 0+4, 16+4, 1+4, 17+4,
+/// 0+8, 16+8, 1+8, 17+8,
+/// 0+12, 16+12, 1+12, 17+12].
+static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+ int numBits) {
+ int numElem = numBits / 32;
+ auto shuffle = b.create<vector::ShuffleOp>(
+ v1, v2,
+ getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
+ return shuffle;
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+/// vector.shuffle, v1, v2, [2, 18, 3, 19,
+/// 2+4, 18+4, 3+4, 19+4,
+/// 2+8, 18+8, 3+8, 19+8,
+/// 2+12, 18+12, 3+12, 19+12].
+static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+ int numBits) {
+ int numElem = numBits / 32;
+ return b.create<vector::ShuffleOp>(
+ v1, v2,
+ getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
+ numBits));
+}
+
+/// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
+/// elements) selected by `mask` from `v1` and `v2`. I.e.,
+///
+/// DEFINE SELECT4(src, control) {
+/// CASE(control[1:0]) OF
+/// 0: tmp[127:0] := src[127:0]
+/// 1: tmp[127:0] := src[255:128]
+/// 2: tmp[127:0] := src[383:256]
+/// 3: tmp[127:0] := src[511:384]
+/// ESAC
+/// RETURN tmp[127:0]
+/// }
+/// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
+/// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
+/// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
+/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
+static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
+ uint8_t mask) {
+ assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
+ "expected a vector with length=16");
+ SmallVector<int64_t> shuffleMask;
+ auto appendToMask = [&](int64_t base, uint8_t control) {
+ switch (control) {
+ case 0:
+ llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
+ base + 2, base + 3});
+ break;
+ case 1:
+ llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
+ base + 6, base + 7});
+ break;
+ case 2:
+ llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
+ base + 10, base + 11});
+ break;
+ case 3:
+ llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
+ base + 14, base + 15});
+ break;
+ default:
+ llvm_unreachable("control > 3 : overflow");
+ }
+ };
+ uint8_t b01 = mask & 0x3;
+ uint8_t b23 = (mask >> 2) & 0x3;
+ uint8_t b45 = (mask >> 4) & 0x3;
+ uint8_t b67 = (mask >> 6) & 0x3;
+ appendToMask(0, b01);
+ appendToMask(0, b23);
+ appendToMask(16, b45);
+ appendToMask(16, b67);
+ return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
+/// 1-D vector and have `m`x`n` elements.
+static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
+ SmallVector<int64_t> mask;
+ mask.reserve(m * n);
+ for (int64_t j = 0; j < n; ++j)
+ for (int64_t i = 0; i < m; ++i)
+ mask.push_back(i * n + j);
+ return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
+}
+
+/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
+/// expected to be a 16x16 vector.
+static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
+ int n) {
+ ImplicitLocOpBuilder b(source.getLoc(), builder);
+ SmallVector<Value> vs;
+ for (int64_t i = 0; i < m; ++i)
+ vs.push_back(b.create<vector::ExtractOp>(source, i));
+
+ // Interleave 32-bit lanes using
+ // 8x _mm512_unpacklo_epi32
+ // 8x _mm512_unpackhi_epi32
+ Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
+ Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
+ Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
+ Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
+ Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
+ Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
+ Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
+ Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
+ Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
+ Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
+ Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
+ Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
+ Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
+ Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
+ Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
+ Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
+
+ // Interleave 64-bit lanes using
+ // 8x _mm512_unpacklo_epi64
+ // 8x _mm512_unpackhi_epi64
+ Value r0 = createUnpackLoPd(b, t0, t2, 512);
+ Value r1 = createUnpackHiPd(b, t0, t2, 512);
+ Value r2 = createUnpackLoPd(b, t1, t3, 512);
+ Value r3 = createUnpackHiPd(b, t1, t3, 512);
+ Value r4 = createUnpackLoPd(b, t4, t6, 512);
+ Value r5 = createUnpackHiPd(b, t4, t6, 512);
+ Value r6 = createUnpackLoPd(b, t5, t7, 512);
+ Value r7 = createUnpackHiPd(b, t5, t7, 512);
+ Value r8 = createUnpackLoPd(b, t8, ta, 512);
+ Value r9 = createUnpackHiPd(b, t8, ta, 512);
+ Value ra = createUnpackLoPd(b, t9, tb, 512);
+ Value rb = createUnpackHiPd(b, t9, tb, 512);
+ Value rc = createUnpackLoPd(b, tc, te, 512);
+ Value rd = createUnpackHiPd(b, tc, te, 512);
+ Value re = createUnpackLoPd(b, td, tf, 512);
+ Value rf = createUnpackHiPd(b, td, tf, 512);
+
+ // Permute 128-bit lanes using
+ // 16x _mm512_shuffle_i32x4
+ t0 = create4x128BitSuffle(b, r0, r4, 0x88);
+ t1 = create4x128BitSuffle(b, r1, r5, 0x88);
+ t2 = create4x128BitSuffle(b, r2, r6, 0x88);
+ t3 = create4x128BitSuffle(b, r3, r7, 0x88);
+ t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
+ t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
+ t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
+ t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
+ t8 = create4x128BitSuffle(b, r8, rc, 0x88);
+ t9 = create4x128BitSuffle(b, r9, rd, 0x88);
+ ta = create4x128BitSuffle(b, ra, re, 0x88);
+ tb = create4x128BitSuffle(b, rb, rf, 0x88);
+ tc = create4x128BitSuffle(b, r8, rc, 0xdd);
+ td = create4x128BitSuffle(b, r9, rd, 0xdd);
+ te = create4x128BitSuffle(b, ra, re, 0xdd);
+ tf = create4x128BitSuffle(b, rb, rf, 0xdd);
+
+ // Permute 256-bit lanes using again
+ // 16x _mm512_shuffle_i32x4
+ vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
+ vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
+ vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
+ vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
+ vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
+ vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
+ vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
+ vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
+ vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
+ vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
+ vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
+ vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
+ vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
+ vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
+ vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
+ vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
+
+ auto reshInputType = VectorType::get(
+ {m, n}, source.getType().cast<VectorType>().getElementType());
+ Value res =
+ b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
+ for (int64_t i = 0; i < m; ++i)
+ res = b.create<vector::InsertOp>(vs[i], res, i);
+ return res;
+}
+
namespace {
/// Progressive lowering of TransposeOp.
/// One:
@@ -84,8 +331,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
- if (vectorTransformOptions.vectorTransposeLowering ==
- vector::VectorTransposeLowering::Shuffle &&
+ if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
@@ -145,10 +391,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransformsOptions vectorTransformOptions;
};
-/// Rewrite a 2-D vector.transpose as a sequence of:
+/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
+/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
/// vector.shuffle
/// vector.shape_cast 1D -> 2D
+/// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
+/// ops on 16xf32 vectors.
class TransposeOp2DToShuffleLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
@@ -174,24 +423,28 @@ class TransposeOp2DToShuffleLowering
if (transp[0] != 1 && transp[1] != 0)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
- if (vectorTransformOptions.vectorTransposeLowering !=
- VectorTransposeLowering::Shuffle)
- return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
-
+ Value res;
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
- Value casted = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get({m * n}, srcType.getElementType()),
- op.getVector());
- SmallVector<int64_t> mask;
- mask.reserve(m * n);
- for (int64_t j = 0; j < n; ++j)
- for (int64_t i = 0; i < m; ++i)
- mask.push_back(i * n + j);
-
- Value shuffled =
- rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+ switch (vectorTransformOptions.vectorTransposeLowering) {
+ case VectorTransposeLowering::Shuffle1D: {
+ Value casted = rewriter.create<vector::ShapeCastOp>(
+ loc, VectorType::get({m * n}, srcType.getElementType()),
+ op.getVector());
+ res = transposeToShuffle1D(rewriter, casted, m, n);
+ break;
+ }
+ case VectorTransposeLowering::Shuffle16x16:
+ if (m != 16 || n != 16)
+ return failure();
+ res = transposeToShuffle16x16(rewriter, op.getVector(), m, n);
+ break;
+ case VectorTransposeLowering::EltWise:
+ case VectorTransposeLowering::Flat:
+ return failure();
+ }
+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- op, op.getResultVectorType(), shuffled);
+ op, op.getResultVectorType(), res);
return success();
}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 427e6075c5dcb..0ca8206407774 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -54,6 +54,6 @@ transform.sequence failures(propagate) {
: (!pdl.operation) -> !pdl.operation
%func_8 = transform.vector.lower_transpose %func_7
- lowering_strategy = "shuffle"
+ lowering_strategy = "shuffle_1d"
: (!pdl.operation) -> !pdl.operation
}
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index cd933193299a9..2fbee42109706 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -57,6 +57,6 @@ transform.sequence failures(propagate) {
: (!pdl.operation) -> !pdl.operation
%func_8 = transform.vector.lower_transpose %func_7
- lowering_strategy = "shuffle"
+ lowering_strategy = "shuffle_1d"
: (!pdl.operation) -> !pdl.operation
}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index a2fc23386c25d..44bba6fbed982 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -100,7 +100,7 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.lower_transpose %module_op
- lowering_strategy = "shuffle"
+ lowering_strategy = "shuffle_1d"
: (!pdl.operation) -> !pdl.operation
}
@@ -609,3 +609,81 @@ transform.sequence failures(propagate) {
avx2_lowering_strategy = true
: (!pdl.operation) -> !pdl.operation
}
+
+// -----
+
+func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16xf32> {
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+ return %0 : vector<16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ transform.vector.lower_transpose %module_op
+ lowering_strategy = "shuffle_16x16"
+ : (!pdl.operation) -> !pdl.operation
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
new file mode 100644
index 0000000000000..147d9a7f2ff4c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-scf-to-cf \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+ %in = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0], [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0], [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0], [48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0], [64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0], [80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0], [96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0], [112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0], [128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, 143.0], [144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0], [160.0, 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, 174.0, 175.0], [176.0, 177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0], [192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0], [208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0], [224.0, 225.0, 226.0, 227.0, 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, 239.0], [240.0, 241.0, 242.0, 243.0, 244.0, 245.0, 246.0, 247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0, 255.0]]> : vector<16x16xf32>
+ %0 = vector.transpose %in, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+ vector.print %0 : vector<16x16xf32>
+ // CHECK: ( ( 0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240 ),
+ // CHECK-SAME: ( 1, 17, 33, 49, 65, 81, 97, 113, 129, 145, 161, 177, 193, 209, 225, 241 ),
+ // CHECK-SAME: ( 2, 18, 34, 50, 66, 82, 98, 114, 130, 146, 162, 178, 194, 210, 226, 242 ),
+ // CHECK-SAME: ( 3, 19, 35, 51, 67, 83, 99, 115, 131, 147, 163, 179, 195, 211, 227, 243 ),
+ // CHECK-SAME: ( 4, 20, 36, 52, 68, 84, 100, 116, 132, 148, 164, 180, 196, 212, 228, 244 ),
+ // CHECK-SAME: ( 5, 21, 37, 53, 69, 85, 101, 117, 133, 149, 165, 181, 197, 213, 229, 245 ),
+ // CHECK-SAME: ( 6, 22, 38, 54, 70, 86, 102, 118, 134, 150, 166, 182, 198, 214, 230, 246 ),
+ // CHECK-SAME: ( 7, 23, 39, 55, 71, 87, 103, 119, 135, 151, 167, 183, 199, 215, 231, 247 ),
+ // CHECK-SAME: ( 8, 24, 40, 56, 72, 88, 104, 120, 136, 152, 168, 184, 200, 216, 232, 248 ),
+ // CHECK-SAME: ( 9, 25, 41, 57, 73, 89, 105, 121, 137, 153, 169, 185, 201, 217, 233, 249 ),
+ // CHECK-SAME: ( 10, 26, 42, 58, 74, 90, 106, 122, 138, 154, 170, 186, 202, 218, 234, 250 ),
+ // CHECK-SAME: ( 11, 27, 43, 59, 75, 91, 107, 123, 139, 155, 171, 187, 203, 219, 235, 251 ),
+ // CHECK-SAME: ( 12, 28, 44, 60, 76, 92, 108, 124, 140, 156, 172, 188, 204, 220, 236, 252 ),
+ // CHECK-SAME: ( 13, 29, 45, 61, 77, 93, 109, 125, 141, 157, 173, 189, 205, 221, 237, 253 ),
+ // CHECK-SAME: ( 14, 30, 46, 62, 78, 94, 110, 126, 142, 158, 174, 190, 206, 222, 238, 254 ),
+ // CHECK-SAME: ( 15, 31, 47, 63, 79, 95, 111, 127, 143, 159, 175, 191, 207, 223, 239, 255 ) )
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ transform.vector.lower_transpose %module_op
+ lowering_strategy = "shuffle_16x16"
+ : (!pdl.operation) -> !pdl.operation
+}
+
More information about the Mlir-commits
mailing list