[Mlir-commits] [mlir] 8d163e5 - [mlir][Vector] Add 16x16 strategy to vector.transpose lowering.

Hanhan Wang llvmlistbot at llvm.org
Sun Apr 23 11:05:50 PDT 2023


Author: Hanhan Wang
Date: 2023-04-23T11:05:41-07:00
New Revision: 8d163e5045073a5ac570225cc8e14cc9f6d72f09

URL: https://github.com/llvm/llvm-project/commit/8d163e5045073a5ac570225cc8e14cc9f6d72f09
DIFF: https://github.com/llvm/llvm-project/commit/8d163e5045073a5ac570225cc8e14cc9f6d72f09.diff

LOG: [mlir][Vector] Add 16x16 strategy to vector.transpose lowering.

It adds a `shuffle_16x16` strategy to LowerVectorTranspose and renames `shuffle` to `shuffle_1d`. The idea is similar to the 8x8 case in x86Vector::avx2. The general algorithm is:

```
interleave 32-bit lanes using
    8x _mm512_unpacklo_epi32
    8x _mm512_unpackhi_epi32
interleave 64-bit lanes using
    8x _mm512_unpacklo_epi64
    8x _mm512_unpackhi_epi64
permute 128-bit lanes using
   16x _mm512_shuffle_i32x4
permute 256-bit lanes, again using
   16x _mm512_shuffle_i32x4
```

After the first stage, they got transposed to

```
 0  16   1  17   4  20   5  21   8  24   9  25  12  28  13  29
 2  18   3  19   6  22   7  23  10  26  11  27  14  30  15  31
32  48  33  49 ...
34  50  35  51 ...
64  80  65  81 ...
...
```

After the second stage, they got transposed to

```
 0  16  32  48 ...
 1  17  33  49 ...
 2  18  34  50 ...
 3  19  35  51 ...
64  80  96 112 ...
65  81  97 113 ...
66  82  98 114 ...
67  83  99 115 ...
...
```

After the third stage, they got transposed to

```
  0  16  32  48   8  24  40  56  64  80  96  112 ...
  1  17  33  49 ...
  2  18  34  50 ...
  3  19  35  51 ...
  4  20  36  52 ...
  5  21  37  53 ...
  6  22  38  54 ...
  7  23  39  55 ...
128 144 160 176 ...
129 145 161 177 ...
...
```

After the last stage, they got transposed to

```
0  16  32  48  64  80  96 112 ... 240
1  17  33  49  65  81  97 113 ... 241
2  18  34  50  66  82  98 114 ... 242
...
15  31  47  63  79  95 111 127 ... 255
```

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D148685

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
    mlir/test/Dialect/LLVM/transform-e2e.mlir
    mlir/test/Dialect/Vector/transform-vector.mlir
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
index fb1f2ab717687..ef0951ab1d166 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
@@ -18,14 +18,17 @@ def VectorTransposeLowering_Elementwise:
 // intrinsics.
 def VectorTransposeLowering_FlatTranspose:
   I32EnumAttrCase<"Flat",  1, "flat_transpose">;
-// Lower 2-D transpose to `vector.shuffle`.
-def VectorTransposeLowering_Shuffle:
-  I32EnumAttrCase<"Shuffle",  2, "shuffle">;
+// Lower 2-D transpose to `vector.shuffle` on 1-D vector.
+def VectorTransposeLowering_Shuffle1D:
+  I32EnumAttrCase<"Shuffle1D",  2, "shuffle_1d">;
+// Lower 2-D transpose to `vector.shuffle` on 16x16 vector.
+def VectorTransposeLowering_Shuffle16x16:
+  I32EnumAttrCase<"Shuffle16x16",  3, "shuffle_16x16">;
 def VectorTransposeLoweringAttr : I32EnumAttr<
     "VectorTransposeLowering",
     "control the lowering of `vector.transpose` operations.",
     [VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
-     VectorTransposeLowering_Shuffle]> {
+     VectorTransposeLowering_Shuffle1D, VectorTransposeLowering_Shuffle16x16]> {
   let cppNamespace = "::mlir::vector";
 }
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 2888e347ce441..1408c03f21456 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -52,6 +52,253 @@ static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
 }
 
+/// Returns true iff `lowering` is Shuffle1D or Shuffle16x16 (shuffle-based).
+static bool isShuffleLike(VectorTransposeLowering lowering) {
+  return lowering == VectorTransposeLowering::Shuffle1D ||
+         lowering == VectorTransposeLowering::Shuffle16x16;
+}
+
+/// Builds a shuffle mask by repeating the unpack pattern `vals` once per
+/// 128-bit lane of a `numBits`-bit vector. A lane is counted as four 32-bit
+/// elements, so copy k of `vals` is offset by 4 * k. `numBits` has to be a
+/// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
+/// 512, there are four copies, i.e., 16 elements in the final result. It
+/// constructs the below mask to get the unpack elements.
+///   [0,    1,    16,    17,
+///    0+4,  1+4,  16+4,  17+4,
+///    0+8,  1+8,  16+8,  17+8,
+///    0+12, 1+12, 16+12, 17+12]
+static SmallVector<int64_t>
+getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
+  assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
+  int numElem = numBits / 32; // Element count, assuming 32-bit elements.
+  SmallVector<int64_t> res;
+  for (int i = 0; i < numElem; i += 4) // One iteration per 128-bit lane.
+    for (int64_t v : vals)
+      res.push_back(v + i);
+  return res;
+}
+
+/// Creates a vector.shuffle with UnpackLoPd mask: per 128-bit lane, elements
+/// 0-1 of `v1` followed by elements 0-1 of `v2`. For a 512-bit vector, returns
+///   vector.shuffle on v1, v2, [0,    1,    16,    17,
+///                              0+4,  1+4,  16+4,  17+4,
+///                              0+8,  1+8,  16+8,  17+8,
+///                              0+12, 1+12, 16+12, 17+12].
+static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32; // Mask indices >= numElem select from `v2`.
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
+}
+
+/// Creates a vector.shuffle with UnpackHiPd mask: per 128-bit lane, elements
+/// 2-3 of `v1` followed by elements 2-3 of `v2`. For a 512-bit vector, returns
+///   vector.shuffle, v1, v2, [2,    3,    18,    19,
+///                            2+4,  3+4,  18+4,  19+4,
+///                            2+8,  3+8,  18+8,  19+8,
+///                            2+12, 3+12, 18+12, 19+12].
+static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32; // Mask indices >= numElem select from `v2`.
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
+                                     numBits));
+}
+
+/// Creates a vector.shuffle with UnpackLoPs mask: per 128-bit lane, elements
+/// 0-1 of `v1` interleaved with those of `v2`. For a 512-bit vector, returns
+///   vector.shuffle, v1, v2, [0,    16,    1,    17,
+///                            0+4,  16+4,  1+4,  17+4,
+///                            0+8,  16+8,  1+8,  17+8,
+///                            0+12, 16+12, 1+12, 17+12].
+static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32; // Mask indices >= numElem select from `v2`.
+  auto shuffle = b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
+  return shuffle;
+}
+
+/// Creates a vector.shuffle with UnpackHiPs mask: per 128-bit lane, elements
+/// 2-3 of `v1` interleaved with those of `v2`. For a 512-bit vector, returns
+///   vector.shuffle, v1, v2, [2,    18,    3,    19,
+///                            2+4,  18+4,  3+4,  19+4,
+///                            2+8,  18+8,  3+8,  19+8,
+///                            2+12, 18+12, 3+12, 19+12].
+static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32; // Mask indices >= numElem select from `v2`.
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
+                                     numBits));
+}
+
+/// Returns a vector.shuffle whose four 128-bit result lanes (4 elements each)
+/// are chosen by 2-bit fields of `mask`: two from `v1`, then two from `v2`:
+///
+/// DEFINE SELECT4(src, control) {
+///	CASE(control[1:0]) OF
+///	0:	tmp[127:0] := src[127:0]
+///	1:	tmp[127:0] := src[255:128]
+///	2:	tmp[127:0] := src[383:256]
+///	3:	tmp[127:0] := src[511:384]
+///	ESAC
+///	RETURN tmp[127:0]
+/// }
+/// dst[127:0]   := SELECT4(v1[511:0], mask[1:0])
+/// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
+/// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
+/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
+static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                                  uint8_t mask) {
+  assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
+         "expected a vector with length=16");
+  SmallVector<int64_t> shuffleMask;
+  auto appendToMask = [&](int64_t base, uint8_t control) {
+    switch (control) { // Appends the 4 indices of lane `control` past `base`.
+    case 0:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
+                                                        base + 2, base + 3});
+      break;
+    case 1:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
+                                                        base + 6, base + 7});
+      break;
+    case 2:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
+                                                        base + 10, base + 11});
+      break;
+    case 3:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
+                                                        base + 14, base + 15});
+      break;
+    default:
+      llvm_unreachable("control > 3 : overflow");
+    }
+  };
+  uint8_t b01 = mask & 0x3;
+  uint8_t b23 = (mask >> 2) & 0x3;
+  uint8_t b45 = (mask >> 4) & 0x3;
+  uint8_t b67 = (mask >> 6) & 0x3;
+  appendToMask(0, b01);  // dst[127:0]:   indices 0..15 address `v1`.
+  appendToMask(0, b23);  // dst[255:128]: still `v1`.
+  appendToMask(16, b45); // dst[383:256]: indices 16..31 address `v2`.
+  appendToMask(16, b67); // dst[511:384]: still `v2`.
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+/// Creates one vector.shuffle of the 1-D `source` (expected to hold `m`x`n`
+/// elements) with itself; result position j * m + i reads element i * n + j.
+static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
+  SmallVector<int64_t> mask;
+  mask.reserve(m * n);
+  for (int64_t j = 0; j < n; ++j)
+    for (int64_t i = 0; i < m; ++i)
+      mask.push_back(i * n + j);
+  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
+}
+
+/// Transposes `source` (expected to be a 16x16 vector, so `m` == `n` == 16) by
+/// extracting its rows and applying four rounds of 16 vector.shuffle ops each.
+static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
+                                     int n) {
+  ImplicitLocOpBuilder b(source.getLoc(), builder);
+  SmallVector<Value> vs;
+  for (int64_t i = 0; i < m; ++i) // vs[i] is row i of `source`.
+    vs.push_back(b.create<vector::ExtractOp>(source, i));
+
+  // Interleave 32-bit lanes using
+  //   8x _mm512_unpacklo_epi32
+  //   8x _mm512_unpackhi_epi32
+  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
+  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
+  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
+  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
+  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
+  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
+  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
+  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
+  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
+  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
+  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
+  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
+  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
+  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
+  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
+  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
+
+  // Interleave 64-bit lanes using
+  //   8x _mm512_unpacklo_epi64
+  //   8x _mm512_unpackhi_epi64
+  Value r0 = createUnpackLoPd(b, t0, t2, 512);
+  Value r1 = createUnpackHiPd(b, t0, t2, 512);
+  Value r2 = createUnpackLoPd(b, t1, t3, 512);
+  Value r3 = createUnpackHiPd(b, t1, t3, 512);
+  Value r4 = createUnpackLoPd(b, t4, t6, 512);
+  Value r5 = createUnpackHiPd(b, t4, t6, 512);
+  Value r6 = createUnpackLoPd(b, t5, t7, 512);
+  Value r7 = createUnpackHiPd(b, t5, t7, 512);
+  Value r8 = createUnpackLoPd(b, t8, ta, 512);
+  Value r9 = createUnpackHiPd(b, t8, ta, 512);
+  Value ra = createUnpackLoPd(b, t9, tb, 512);
+  Value rb = createUnpackHiPd(b, t9, tb, 512);
+  Value rc = createUnpackLoPd(b, tc, te, 512);
+  Value rd = createUnpackHiPd(b, tc, te, 512);
+  Value re = createUnpackLoPd(b, td, tf, 512);
+  Value rf = createUnpackHiPd(b, td, tf, 512);
+
+  // Permute 128-bit lanes using
+  //   16x _mm512_shuffle_i32x4
+  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
+  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
+  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
+  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
+  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
+  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
+  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
+  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
+  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
+  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
+  ta = create4x128BitSuffle(b, ra, re, 0x88);
+  tb = create4x128BitSuffle(b, rb, rf, 0x88);
+  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
+  td = create4x128BitSuffle(b, r9, rd, 0xdd);
+  te = create4x128BitSuffle(b, ra, re, 0xdd);
+  tf = create4x128BitSuffle(b, rb, rf, 0xdd);
+
+  // Permute 256-bit lanes, again using
+  //   16x _mm512_shuffle_i32x4
+  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
+  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
+  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
+  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
+  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
+  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
+  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
+  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
+  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
+  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
+  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
+  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
+  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
+  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
+  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
+  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
+
+  auto reshInputType = VectorType::get(
+      {m, n}, source.getType().cast<VectorType>().getElementType());
+  Value res =
+      b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
+  for (int64_t i = 0; i < m; ++i) // Insert each shuffled row into the result.
+    res = b.create<vector::InsertOp>(vs[i], res, i);
+  return res;
+}
+
 namespace {
 /// Progressive lowering of TransposeOp.
 /// One:
@@ -84,8 +331,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     for (auto attr : op.getTransp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Shuffle &&
+    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
@@ -145,10 +391,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransformsOptions vectorTransformOptions;
 };
 
-/// Rewrite a 2-D vector.transpose as a sequence of:
+/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
+/// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
 ///   vector.shuffle
 ///   vector.shape_cast 1D -> 2D
+/// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
+/// ops on 16xf32 vectors.
 class TransposeOp2DToShuffleLowering
     : public OpRewritePattern<vector::TransposeOp> {
 public:
@@ -174,24 +423,28 @@ class TransposeOp2DToShuffleLowering
     if (transp[0] != 1 && transp[1] != 0)
       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
 
-    if (vectorTransformOptions.vectorTransposeLowering !=
-        VectorTransposeLowering::Shuffle)
-      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
-
+    Value res;
     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
-    Value casted = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({m * n}, srcType.getElementType()),
-        op.getVector());
-    SmallVector<int64_t> mask;
-    mask.reserve(m * n);
-    for (int64_t j = 0; j < n; ++j)
-      for (int64_t i = 0; i < m; ++i)
-        mask.push_back(i * n + j);
-
-    Value shuffled =
-        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+    switch (vectorTransformOptions.vectorTransposeLowering) {
+    case VectorTransposeLowering::Shuffle1D: {
+      Value casted = rewriter.create<vector::ShapeCastOp>(
+          loc, VectorType::get({m * n}, srcType.getElementType()),
+          op.getVector());
+      res = transposeToShuffle1D(rewriter, casted, m, n);
+      break;
+    }
+    case VectorTransposeLowering::Shuffle16x16:
+      if (m != 16 || n != 16)
+        return failure();
+      res = transposeToShuffle16x16(rewriter, op.getVector(), m, n);
+      break;
+    case VectorTransposeLowering::EltWise:
+    case VectorTransposeLowering::Flat:
+      return failure();
+    }
+
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        op, op.getResultVectorType(), shuffled);
+        op, op.getResultVectorType(), res);
 
     return success();
   }

diff  --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 427e6075c5dcb..0ca8206407774 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -54,6 +54,6 @@ transform.sequence failures(propagate) {
     : (!pdl.operation) -> !pdl.operation
 
   %func_8 = transform.vector.lower_transpose %func_7
-    lowering_strategy = "shuffle"
+    lowering_strategy = "shuffle_1d"
       : (!pdl.operation) -> !pdl.operation
 }

diff  --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index cd933193299a9..2fbee42109706 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -57,6 +57,6 @@ transform.sequence failures(propagate) {
     : (!pdl.operation) -> !pdl.operation
 
   %func_8 = transform.vector.lower_transpose %func_7
-    lowering_strategy = "shuffle"
+    lowering_strategy = "shuffle_1d"
       : (!pdl.operation) -> !pdl.operation
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index a2fc23386c25d..44bba6fbed982 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -100,7 +100,7 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   transform.vector.lower_transpose %module_op
-    lowering_strategy = "shuffle"
+    lowering_strategy = "shuffle_1d"
       : (!pdl.operation) -> !pdl.operation
 }
 
@@ -609,3 +609,81 @@ transform.sequence failures(propagate) {
     avx2_lowering_strategy = true
       : (!pdl.operation) -> !pdl.operation
 }
+
+// -----
+
+func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16xf32> {
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  return %0 : vector<16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  transform.vector.lower_transpose %module_op
+    lowering_strategy = "shuffle_16x16"
+      : (!pdl.operation) -> !pdl.operation
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
new file mode 100644
index 0000000000000..147d9a7f2ff4c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-scf-to-cf \
+// RUN:   -test-transform-dialect-interpreter \
+// RUN:   -test-transform-dialect-erase-schedule \
+// RUN:   -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %in = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0], [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0], [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0], [48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0], [64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0], [80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0], [96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0], [112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0], [128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, 143.0], [144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0], [160.0, 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, 174.0, 175.0], [176.0, 177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0], [192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0], [208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0], [224.0, 225.0, 226.0, 227.0, 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, 239.0], [240.0, 241.0, 242.0, 243.0, 244.0, 245.0, 246.0, 247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0, 255.0]]> : vector<16x16xf32>
+  %0 = vector.transpose %in, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  vector.print %0 : vector<16x16xf32>
+  // CHECK:     ( ( 0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240 ),
+  // CHECK-SAME:  ( 1, 17, 33, 49, 65, 81, 97, 113, 129, 145, 161, 177, 193, 209, 225, 241 ),
+  // CHECK-SAME:  ( 2, 18, 34, 50, 66, 82, 98, 114, 130, 146, 162, 178, 194, 210, 226, 242 ),
+  // CHECK-SAME:  ( 3, 19, 35, 51, 67, 83, 99, 115, 131, 147, 163, 179, 195, 211, 227, 243 ),
+  // CHECK-SAME:  ( 4, 20, 36, 52, 68, 84, 100, 116, 132, 148, 164, 180, 196, 212, 228, 244 ),
+  // CHECK-SAME:  ( 5, 21, 37, 53, 69, 85, 101, 117, 133, 149, 165, 181, 197, 213, 229, 245 ),
+  // CHECK-SAME:  ( 6, 22, 38, 54, 70, 86, 102, 118, 134, 150, 166, 182, 198, 214, 230, 246 ),
+  // CHECK-SAME:  ( 7, 23, 39, 55, 71, 87, 103, 119, 135, 151, 167, 183, 199, 215, 231, 247 ),
+  // CHECK-SAME:  ( 8, 24, 40, 56, 72, 88, 104, 120, 136, 152, 168, 184, 200, 216, 232, 248 ),
+  // CHECK-SAME:  ( 9, 25, 41, 57, 73, 89, 105, 121, 137, 153, 169, 185, 201, 217, 233, 249 ),
+  // CHECK-SAME:  ( 10, 26, 42, 58, 74, 90, 106, 122, 138, 154, 170, 186, 202, 218, 234, 250 ),
+  // CHECK-SAME:  ( 11, 27, 43, 59, 75, 91, 107, 123, 139, 155, 171, 187, 203, 219, 235, 251 ),
+  // CHECK-SAME:  ( 12, 28, 44, 60, 76, 92, 108, 124, 140, 156, 172, 188, 204, 220, 236, 252 ),
+  // CHECK-SAME:  ( 13, 29, 45, 61, 77, 93, 109, 125, 141, 157, 173, 189, 205, 221, 237, 253 ),
+  // CHECK-SAME:  ( 14, 30, 46, 62, 78, 94, 110, 126, 142, 158, 174, 190, 206, 222, 238, 254 ),
+  // CHECK-SAME:  ( 15, 31, 47, 63, 79, 95, 111, 127, 143, 159, 175, 191, 207, 223, 239, 255 ) )
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  transform.vector.lower_transpose %module_op
+    lowering_strategy = "shuffle_16x16"
+      : (!pdl.operation) -> !pdl.operation
+}
+


        


More information about the Mlir-commits mailing list