[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` tree transformation (PR #145740)
Diego Caballero
llvmlistbot at llvm.org
Mon Jun 30 22:54:39 PDT 2025
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/145740
>From 6aaa4e42756a10c95e59afaf2eba3e0b9583709c Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 24 Jun 2025 06:33:22 +0000
Subject: [PATCH 1/2] [mlir][Vector] Add `vector.shuffle` tree transformation
This PR adds a new transformation that turns sequences of `vector.to_elements`
and `vector.from_elements` into a binary tree of `vector.shuffle` operations.
(Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779).
Example:
```
%0:4 = vector.to_elements %a : vector<4xf32>
%1:4 = vector.to_elements %b : vector<4xf32>
%2:4 = vector.to_elements %c : vector<4xf32>
%3 = vector.from_elements %0#0, %0#1, %0#2, %0#3,
%1#0, %1#1, %1#2, %1#3,
%2#0, %2#1, %2#2, %2#3 : vector<12xf32>
==>
%0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
%1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
%2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
```
The algorithm leverages the structured extraction/insertion information of
`vector.to_elements` and `vector.from_elements` operations and builds a set
of intervals to determine the vector length that should be used at each level
of the tree.
There are a few improvements that can be implemented in the future, such as
shuffle mask compression to avoid unnecessarily large vector lengths with poison
values, but I decided to keep things "simpler" and spend more time documenting the
different steps of the algorithm so that people can follow along.
---
.../Vector/Transforms/LoweringPatterns.h | 7 +
.../mlir/Dialect/Vector/Transforms/Passes.h | 1 +
.../mlir/Dialect/Vector/Transforms/Passes.td | 5 +
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
...LowerVectorToFromElementsToShuffleTree.cpp | 692 ++++++++++++++++++
...m-elements-to-shuffle-tree-transforms.mlir | 329 +++++++++
6 files changed, 1035 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
create mode 100644 mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+///
+/// The patterns only match `vector.from_elements` ops with a 1-D, fixed-length
+/// result whose operands are all results of `vector.to_elements` ops. Trivial
+/// forwarding of a single `vector.to_elements` op in order is left untouched.
+void populateVectorToFromElementsToShuffleTreePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+// Pass wrapper that greedily applies the patterns from
+// `populateVectorToFromElementsToShuffleTreePatterns` to each `func.func`.
+def LowerVectorToFromElementsToShuffleTree
+ : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+ let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- LowerVectorToFromElementsToShuffleTree.cpp - Shuffle tree lowering -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+// Pull in the TableGen-generated definition of
+// `impl::LowerVectorToFromElementsToShuffleTreeBase`, which the pass at the
+// end of this file derives from.
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+// Debug tag for the LLVM_DEBUG output emitted in this file.
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Number of spaces emitted per indentation level in the debug output.
+constexpr unsigned kIndScale{2};
+
+/// Closed interval [first, second] of positions in the `vector.from_elements`
+/// result (e.g., [0, 7] spans 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Marker for an interval bound that has not been assigned yet.
+constexpr unsigned kMaxUnsigned{std::numeric_limits<unsigned>::max()};
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+///   1. Vectors are shuffled in pairs by order of appearance in
+///      the `vector.from_elements` operand list.
+///   2. Each input vector to each level is used only once.
+///   3. The number of levels in the tree is:
+///        ceil(log2(# `vector.to_elements` ops)), with a minimum of one.
+///   4. Vectors at each level of the tree have the same vector length.
+///   5. Vector positions that do not need to be shuffled are represented with
+///      poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+///   %0:4 = vector.to_elements %a : vector<4xf32>
+///   %1:4 = vector.to_elements %b : vector<4xf32>
+///   %2:4 = vector.to_elements %c : vector<4xf32>
+///   %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+///                             %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+///                             : vector<12xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+///             : vector<4xf32>, vector<4xf32>
+///   %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+///             : vector<4xf32>, vector<4xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+///                                                  6, 7, 8, 9, 10, 11]
+///             : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * The shuffle tree has two levels:
+///         - Level 1 = (%shuffle0, %shuffle1)
+///         - Level 2 = (%result)
+///     * `%a` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%0#0` and `%1#0`).
+///     * `%c` is shuffled with itself because the number of
+///       `vector.to_elements` ops (3) is odd.
+///     * The vector length for the first and second levels are 8 and 12,
+///       respectively.
+///     * `%shuffle1` uses poison values to match the vector length of its
+///       tree level (8).
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+///             : vector<5xf32>, vector<5xf32>
+///   %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+///             : vector<5xf32>, vector<5xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///             : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * `%c` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%2#2` and `%1#1`).
+///     * `%a` is shuffled with itself because the number of
+///       `vector.to_elements` ops (3) is odd.
+///     * The vector length for the first and second levels are 8 and 9,
+///       respectively.
+///     * `%shuffle0` uses poison values to mark unused vector positions and
+///       match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+  VectorShuffleTreeBuilder() = delete;
+  VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+                           ArrayRef<ToElementsOp> toElemDefs);
+
+  /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+  /// and compute the shuffle tree configuration. This method does not generate
+  /// any IR.
+  LogicalResult computeShuffleTree();
+
+  /// Materialize the shuffle tree configuration computed by
+  /// `computeShuffleTree` in the IR.
+  Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+  // IR input information.
+  FromElementsOp fromElementsOp;
+  /// Input `vector.to_elements` ops. The constructor pads an odd count by
+  /// duplicating the last op.
+  SmallVector<ToElementsOp> toElementsDefs;
+
+  // Shuffle tree configuration, populated by `computeShuffleTree`.
+  unsigned numLevels = 0;
+  /// Output vector length of the shuffles generated at each tree level.
+  SmallVector<unsigned> vectorSizePerLevel;
+  /// Holds the range of positions in the final output that each vector input
+  /// in the tree is contributing to.
+  SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+  // Utility methods to compute the shuffle tree configuration.
+  void computeInputVectorIntervals();
+  void computeOutputVectorSizePerLevel();
+
+  /// Dump the shuffle tree configuration.
+  void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+    FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+    : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+  assert(fromElementsOp && "from_elements op is required");
+  assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+  // Inputs are combined two at a time. When there is an unpaired input at the
+  // end, give it a partner by repeating it, so the first tree level always
+  // has an even number of inputs.
+  bool hasUnpairedInput = (toElementsDefs.size() & 1) != 0;
+  if (!hasUnpairedInput)
+    return;
+  ToElementsOp lastInput = toElementsDefs.back();
+  toElementsDefs.push_back(lastInput);
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0; the last one is duplicated to make
+/// the number of inputs even), so we compute the interval for each input
+/// vector:
+///
+///   * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+///   * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+///   * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+///   * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0,
+/// so we compute the intervals for each input vector to level 1 as:
+///
+///   * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+///   * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+  // Map `vector.to_elements` ops to their ordinal position in the
+  // `vector.from_elements` operand list. `insert` keeps the first mapping, so
+  // a duplicated `vector.to_elements` op is mapped to its first occurrence.
+  DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+  for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+    toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+  // Compute intervals for each input vector in the shuffle tree. The first
+  // level computation is special-cased to keep the implementation simpler.
+  SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+                                            {kMaxUnsigned, kMaxUnsigned});
+
+  for (const auto &[idx, element] :
+       llvm::enumerate(fromElementsOp.getElements())) {
+    auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+    unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+    Interval &currentInterval = firstLevelIntervals[inputIdx];
+
+    // Set lower bound to the first occurrence of the `vector.to_elements`.
+    if (currentInterval.first == kMaxUnsigned)
+      currentInterval.first = idx;
+
+    // Set upper bound to the last occurrence of the `vector.to_elements`.
+    currentInterval.second = idx;
+  }
+
+  // If the number of `vector.to_elements` is odd and the last op was
+  // duplicated, the interval for the duplicated op was not computed in the
+  // previous step as all the input occurrences were mapped to the original op.
+  // We copy the interval of the original op to the interval of the duplicated
+  // op manually.
+  if (firstLevelIntervals.back().second == kMaxUnsigned)
+    firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+  inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+  // Compute intervals for the remaining levels.
+  unsigned outputNumElements =
+      cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+  for (unsigned level = 1; level < numLevels; ++level) {
+    const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+    SmallVector<Interval> currentLevelIntervals(
+        llvm::divideCeil(prevLevelIntervals.size(), 2),
+        {kMaxUnsigned, kMaxUnsigned});
+
+    for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+         ++inputIdx) {
+      auto &interval = currentLevelIntervals[inputIdx];
+      const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+      const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+      // The interval of a vector at the current level is the union of the
+      // intervals of the two input vectors from the previous level being
+      // shuffled at this level. The upper bound is clamped to the last
+      // output position.
+      interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+      interval.second =
+          std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+                   outputNumElements - 1);
+    }
+
+    // `prevLevelIntervals` must not be used past this point: the push_back may
+    // reallocate `inputIntervalsPerLevel`.
+    inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+  }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   Intervals:
+///     * Level 0: [0,6], [1,7], [2,8], [2,8]
+///     * Level 1: [0,7], [2,8]
+///
+///   Vector sizes:
+///     * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+///                    size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+///     * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+  // Compute vector size for each level.
+  for (unsigned level = 0; level < numLevels; ++level) {
+    const auto &currentLevelIntervals = inputIntervalsPerLevel[level];
+    unsigned currentVectorSize = 1;
+    for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+      const auto &lhsInterval = currentLevelIntervals[i];
+      const auto &rhsInterval = currentLevelIntervals[i + 1];
+      // Inputs are ordered by first appearance in the `vector.from_elements`
+      // operand list, so the combined interval starts at the LHS lower bound.
+      // The size computation below relies on this.
+      assert(lhsInterval.first <= rhsInterval.first &&
+             "Expected LHS interval to start at or before RHS interval");
+      unsigned combinedIntervalSize =
+          std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+          1;
+      currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+    }
+    vectorSizePerLevel[level] = currentVectorSize;
+  }
+}
+
+/// Print the inputs and the computed tree configuration (levels, per-level
+/// vector sizes and per-level input intervals) to the debug stream.
+void VectorShuffleTreeBuilder::dump() {
+  LLVM_DEBUG({
+    llvm::raw_ostream &os = llvm::dbgs();
+    auto pad = [](unsigned depth) { return llvm::indent(depth, kIndScale); };
+
+    os << "VectorShuffleTreeBuilder Configuration:\n";
+
+    // Inputs are listed one level deeper than their header.
+    os << pad(1) << "* Inputs:\n";
+    for (ToElementsOp toElementsOp : toElementsDefs)
+      os << pad(2) << toElementsOp << "\n";
+    os << pad(2) << fromElementsOp << "\n\n";
+
+    os << pad(1) << "* Total levels: " << numLevels << "\n";
+    os << pad(1) << "* Vector sizes per level: [";
+    llvm::interleaveComma(vectorSizePerLevel, os);
+    os << "]\n";
+
+    os << pad(1) << "* Input intervals per level:\n";
+    for (const auto &[level, intervals] :
+         llvm::enumerate(inputIntervalsPerLevel)) {
+      os << pad(2) << "* Level " << level << ": ";
+      llvm::interleaveComma(intervals, os, [&os](const Interval &interval) {
+        os << "[" << interval.first << "," << interval.second << "]";
+      });
+      os << "\n";
+    }
+  });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// builds a tree that looks like:
+///
+///          %2    %1                   %0    %0
+///            \  /                       \  /
+///  %2_1 = vector.shuffle     %0_0 = vector.shuffle
+///                  \                    /
+///                %2_1_0_0 = vector.shuffle
+///
+/// The configuration consists of the intervals of the input vectors at each
+/// level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and %2_1_0_0)
+/// and the output vector size for each level. For further details on
+/// intervals and output vector size computation, please take a look at the
+/// corresponding utility functions.
+///
+/// Every level pairs up its inputs, but only the first level is padded (by the
+/// constructor) when its input count is odd. If the padded input count is not
+/// a power of two (e.g., 5 or 6 `vector.to_elements` ops), some intermediate
+/// level would have an odd number of vectors. This method returns failure()
+/// in that case instead of building a malformed tree.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+  // Initialize shuffle tree information based on its size.
+  assert(toElementsDefs.size() > 1 &&
+         "At least two 'vector.to_elements' ops are required");
+
+  // The interval, size and codegen loops all read elements `i` and `i + 1` of
+  // each level. With a non-power-of-two input count, an intermediate level has
+  // an odd number of vectors and those loops would read past its end.
+  // TODO: Support an odd number of inputs at intermediate levels.
+  if (!llvm::isPowerOf2_64(toElementsDefs.size())) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "VectorShuffleTreeBuilder: unsupported number of inputs ("
+               << toElementsDefs.size() << ")\n");
+    return failure();
+  }
+
+  numLevels = llvm::Log2_64(toElementsDefs.size());
+  vectorSizePerLevel.resize(numLevels, 0);
+  inputIntervalsPerLevel.reserve(numLevels);
+
+  computeInputVectorIntervals();
+  computeOutputVectorSizePerLevel();
+  dump();
+
+  return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Build the first-level shuffle mask that combines the sources of two
+/// `vector.to_elements` ops. Each mask entry places a source element at its
+/// final output position, measured relative to the start of the combined
+/// interval of both inputs. Positions that neither input feeds stay poison.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   =>
+///
+///   // Level 0, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+/// TODO: Implement mask compression.
+static SmallVector<int64_t> computePermutationShuffleMask(
+    ToElementsOp toElementOp0, const Interval &interval0,
+    ToElementsOp toElementOp1, const Interval &interval1,
+    FromElementsOp fromElementsOp, unsigned outputVectorSize) {
+  // Both inputs are positioned relative to the start of the combined interval,
+  // which is `interval0.first`. The RHS interval is therefore not needed to
+  // compute mask positions.
+  (void)interval1;
+
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+  // Elements of the second shuffle operand are numbered after all the elements
+  // of the first operand.
+  unsigned rhsPermOffset = toElementOp0.getSource().getType().getNumElements();
+
+  for (const auto &[outputPos, element] :
+       llvm::enumerate(fromElementsOp.getElements())) {
+    auto sourceOp = cast<ToElementsOp>(element.getDefiningOp());
+    bool fromLhs = sourceOp == toElementOp0;
+    bool fromRhs = !fromLhs && sourceOp == toElementOp1;
+    if (!fromLhs && !fromRhs)
+      continue;
+
+    // The source element index is the result number of the operand within its
+    // `vector.to_elements` op.
+    unsigned maskIdx = outputPos - interval0.first;
+    unsigned permVal = cast<OpResult>(element).getResultNumber();
+    if (fromRhs)
+      permVal += rhsPermOffset;
+    mask[maskIdx] = permVal;
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << llvm::indent(1, kIndScale) << "* Permutation mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+    llvm::dbgs() << llvm::indent(2, kIndScale) << "* Combining: "
+                 << toElementOp0 << " and " << toElementOp1 << "\n";
+  });
+
+  return mask;
+}
+
+/// Compute the propagation shuffle mask for combining two intermediate shuffle
+/// operations of the tree. The propagation shuffle mask is the mapping of the
+/// intermediate vector elements, which have already been shuffled to their
+/// relative output position using the mask generated by
+/// `computePermutationShuffleMask`, to their next position in the tree.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   // Level 0, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+///   =>
+///
+///   // Level 1, vector length = 9
+///   PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///
+/// TODO: Implement mask compression.
+///
+static SmallVector<int64_t> computePropagationShuffleMask(
+    ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
+    const Interval &rhsInterval, unsigned outputVectorSize) {
+  ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask();
+  ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask();
+  unsigned inputVectorSize = lhsShuffleMask.size();
+  assert(inputVectorSize == rhsShuffleMask.size() &&
+         "Expected both shuffle masks to have the same size");
+
+  unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+
+  // Propagate every non-poison LHS element in place. The LHS interval starts
+  // where the combined interval starts, so no offset is needed.
+  for (unsigned i = 0; i < inputVectorSize; ++i) {
+    if (lhsShuffleMask[i] == ShuffleOp::kPoisonIndex)
+      continue;
+    assert(i < outputVectorSize && "LHS index out of bounds");
+    mask[i] = i;
+  }
+
+  // Propagate every non-poison RHS element, shifted by the offset between the
+  // two input intervals. All LHS positions are written before this loop, so
+  // the overlap assertion sees every LHS element. With an interleaved loop, a
+  // later LHS write could silently overwrite an RHS entry.
+  for (unsigned i = 0; i < inputVectorSize; ++i) {
+    if (rhsShuffleMask[i] == ShuffleOp::kPoisonIndex)
+      continue;
+    unsigned rhsIdx = i + lhsRhsOffset;
+    assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
+    assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set");
+    mask[rhsIdx] = i + inputVectorSize;
+  }
+
+  LLVM_DEBUG({
+    unsigned indLv = 1;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Propagation shuffle mask computation:\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* LHS shuffle op: " << lhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* RHS shuffle op: " << rhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+  });
+
+  return mask;
+}
+
+/// Emit the `vector.shuffle` ops described by the configuration computed in
+/// `computeShuffleTree` and return the value produced by the root shuffle.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// with the pre-computed shuffle tree configuration:
+///
+///   * Vector sizes per level: [8, 9]
+///   * Input intervals per level:
+///     * Level 0: [0,6], [1,7], [2,8], [2,8]
+///     * Level 1: [0,7], [2,8]
+///
+///   =>
+///
+///   %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6]
+///      : vector<5xf32>, vector<5xf32>
+///   %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1]
+///      : vector<5xf32>, vector<5xf32>
+///   %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///      : vector<8xf32>, vector<8xf32>
+///
+/// Each level consumes its inputs two at a time and produces one shuffle per
+/// pair. Two kinds of masks are used:
+///   * Level 0 uses permutation masks (`computePermutationShuffleMask`). They
+///     move the source elements to their relative output positions.
+///   * Later levels use propagation masks (`computePropagationShuffleMask`).
+///     They carry the already-placed elements forward without permuting them.
+///
+Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n");
+
+  // Level 0 consumes the source vectors of the `vector.to_elements` ops; each
+  // later level consumes the shuffles emitted by the level before it.
+  SmallVector<Value> levelInputs;
+  levelInputs.reserve(toElementsDefs.size());
+  for (ToElementsOp toElementsOp : toElementsDefs)
+    levelInputs.push_back(toElementsOp.getSource());
+
+  Location loc = fromElementsOp.getLoc();
+  for (auto [level, levelVectorSize, inputIntervals] :
+       llvm::enumerate(vectorSizePerLevel, inputIntervalsPerLevel)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << llvm::indent(1, kIndScale) << "* Processing level " << level
+               << " (vector size: " << levelVectorSize
+               << ", # inputs: " << levelInputs.size() << ")\n");
+
+    SmallVector<Value> levelOutputs;
+    for (size_t lhsIdx = 0, e = levelInputs.size(); lhsIdx < e; lhsIdx += 2) {
+      size_t rhsIdx = lhsIdx + 1;
+      Value lhsVector = levelInputs[lhsIdx];
+      Value rhsVector = levelInputs[rhsIdx];
+      const Interval &lhsInterval = inputIntervals[lhsIdx];
+      const Interval &rhsInterval = inputIntervals[rhsIdx];
+
+      SmallVector<int64_t> shuffleMask =
+          level == 0
+              ? computePermutationShuffleMask(
+                    toElementsDefs[lhsIdx], lhsInterval,
+                    toElementsDefs[rhsIdx], rhsInterval, fromElementsOp,
+                    levelVectorSize)
+              : computePropagationShuffleMask(
+                    cast<ShuffleOp>(lhsVector.getDefiningOp()), lhsInterval,
+                    cast<ShuffleOp>(rhsVector.getDefiningOp()), rhsInterval,
+                    levelVectorSize);
+
+      Value shuffled = rewriter.create<vector::ShuffleOp>(
+          loc, lhsVector, rhsVector, shuffleMask);
+      levelOutputs.push_back(shuffled);
+    }
+
+    levelInputs = std::move(levelOutputs);
+  }
+
+  assert(levelInputs.size() == 1 && "Should have exactly one result");
+  return levelInputs.front();
+}
+
+/// Collect the distinct `vector.to_elements` ops that define the operands of
+/// `fromElementsOp`, in order of first appearance in its operand list. Fails,
+/// leaving `toElementsDefs` untouched, if any operand is not a
+/// `vector.to_elements` result.
+static LogicalResult
+getToElementsDefiningOps(FromElementsOp fromElementsOp,
+                         SmallVectorImpl<ToElementsOp> &toElementsDefs) {
+  SetVector<ToElementsOp> uniqueDefs;
+  bool allFromToElements =
+      llvm::all_of(fromElementsOp.getElements(), [&](Value element) {
+        auto def = element.getDefiningOp<ToElementsOp>();
+        if (def)
+          uniqueDefs.insert(def);
+        return static_cast<bool>(def);
+      });
+  if (!allFromToElements)
+    return failure();
+
+  toElementsDefs.assign(uniqueDefs.begin(), uniqueDefs.end());
+  return success();
+}
+
+/// Pattern to rewrite `vector.to_elements` + `vector.from_elements` sequences
+/// into a tree of `vector.shuffle` operations.
+///
+/// The pattern applies only when all of the following hold:
+///   * the `vector.from_elements` result is a 1-D, fixed-length vector;
+///   * every operand is produced by a `vector.to_elements` op;
+///   * every `vector.to_elements` source has rank at most 1.
+struct ToFromElementsToShuffleTreeRewrite final
+    : OpRewritePattern<vector::FromElementsOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = fromElementsOp.getType();
+    if (resultType.getRank() != 1 || resultType.isScalable())
+      return failure();
+
+    SmallVector<ToElementsOp> toElementsDefs;
+    if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs)))
+      return failure();
+
+    // The generated masks use the `vector.to_elements` result number as the
+    // shuffle index. `vector.shuffle` masks index only the leading dimension of
+    // their operands, so this holds only for sources of rank <= 1. An n-D
+    // source would yield out-of-range mask values and a result type that does
+    // not match the `vector.from_elements` result.
+    if (llvm::any_of(toElementsDefs, [](ToElementsOp toElementsOp) {
+          return toElementsOp.getSource().getType().getRank() > 1;
+        }))
+      return failure();
+
+    // Avoid generating a shuffle tree for trivial `vector.to_elements` ->
+    // `vector.from_elements` forwarding cases that do not require shuffling.
+    if (toElementsDefs.size() == 1) {
+      ToElementsOp toElementsOp0 = toElementsDefs.front();
+      if (llvm::equal(fromElementsOp.getElements(), toElementsOp0.getResults()))
+        return failure();
+    }
+
+    VectorShuffleTreeBuilder shuffleTreeBuilder(fromElementsOp, toElementsDefs);
+    if (failed(shuffleTreeBuilder.computeShuffleTree()))
+      return failure();
+
+    Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
+    rewriter.replaceOp(fromElementsOp, finalShuffle);
+    return success();
+  }
+};
+
+/// Pass that greedily applies the `vector.to_elements` +
+/// `vector.from_elements` to shuffle-tree rewrite to the current operation.
+/// Signals pass failure if the greedy driver reports failure.
+struct LowerVectorToFromElementsToShuffleTreePass
+    : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase<
+          LowerVectorToFromElementsToShuffleTreePass> {
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    populateVectorToFromElementsToShuffleTreePatterns(patterns);
+
+    LogicalResult rewriteResult =
+        applyPatternsGreedily(getOperation(), std::move(patterns));
+    if (failed(rewriteResult))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  // Register the single rewrite pattern with the caller-provided benefit.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ToFromElementsToShuffleTreeRewrite>(context, benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
new file mode 100644
index 0000000000000..3dc579be12f0f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
@@ -0,0 +1,329 @@
+// RUN: mlir-opt -lower-vector-to-from-elements-to-shuffle-tree -split-input-file %s | FileCheck %s
+
+// Captured variable names for `vector.shuffle` operations follow the L#SH# convention,
+// where L# is the 0-based tree level the shuffle belongs to (level 0 shuffles the
+// input vectors) and SH# is the 0-based shuffle index within that level.
+// Single input permuted in place: expect one shuffle of %a with itself.
+func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
+ %0:8 = vector.to_elements %a : vector<8xf32>
+ %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_single_input_shuffle(
+// CHECK-SAME: %[[A:.*]]: vector<8xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32>
+// CHECK: return %[[L0SH0]]
+
+// -----
+// Elements of two inputs interleaved: expect a single two-operand shuffle.
+func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>,
+ %b: vector<8xf32>) -> vector<8xf32> {
+ %0:8 = vector.to_elements %a : vector<8xf32>
+ %1:8 = vector.to_elements %b : vector<8xf32>
+ %2 = vector.from_elements %0#7, %1#0, %0#6, %1#1, %0#5, %1#2, %0#4, %1#3 : vector<8xf32>
+ return %2 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @from_elements_to_elements_single_shuffle(
+// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [7, 8, 6, 9, 5, 10, 4, 11] : vector<8xf32>, vector<8xf32>
+// CHECK: return %[[L0SH0]]
+
+// -----
+// Four 8-element inputs concatenated in order: expect a two-level shuffle tree.
+func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
+ %b: vector<8xf32>,
+ %c: vector<8xf32>,
+ %d: vector<8xf32>) -> vector<32xf32> {
+ %0:8 = vector.to_elements %a : vector<8xf32>
+ %1:8 = vector.to_elements %b : vector<8xf32>
+ %2:8 = vector.to_elements %c : vector<8xf32>
+ %3:8 = vector.to_elements %d : vector<8xf32>
+ %4 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7,
+ %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7,
+ %2#0, %2#1, %2#2, %2#3, %2#4, %2#5, %2#6, %2#7,
+ %3#0, %3#1, %3#2, %3#3, %3#4, %3#5, %3#6, %3#7 : vector<32xf32>
+ return %4 : vector<32xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_4x8_to_32(
+// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<8xf32>, %[[D:.*]]: vector<8xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: return %[[L1SH0]] : vector<32xf32>
+
+// -----
+// Odd input count: the unpaired %c is shuffled with itself, padded with -1 (poison).
+func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
+ %b: vector<4xf32>,
+ %c: vector<4xf32>) -> vector<12xf32> {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ %1:4 = vector.to_elements %b : vector<4xf32>
+ %2:4 = vector.to_elements %c : vector<4xf32>
+ %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32>
+ return %3 : vector<12xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_3x4_to_12(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+// CHECK: return %[[L1SH0]] : vector<12xf32>
+
+// -----
+// 64 four-element inputs concatenated in order: expect a balanced six-level tree.
+func.func @to_from_elements_shuffle_tree_concat_64x4_256(
+ %a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>,
+ %e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>,
+ %i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>,
+ %m: vector<4xf32>, %n: vector<4xf32>, %o: vector<4xf32>, %p: vector<4xf32>,
+ %q: vector<4xf32>, %r: vector<4xf32>, %s: vector<4xf32>, %t: vector<4xf32>,
+ %u: vector<4xf32>, %v: vector<4xf32>, %w: vector<4xf32>, %x: vector<4xf32>,
+ %y: vector<4xf32>, %z: vector<4xf32>, %aa: vector<4xf32>, %ab: vector<4xf32>,
+ %ac: vector<4xf32>, %ad: vector<4xf32>, %ae: vector<4xf32>, %af: vector<4xf32>,
+ %ag: vector<4xf32>, %ah: vector<4xf32>, %ai: vector<4xf32>, %aj: vector<4xf32>,
+ %ak: vector<4xf32>, %al: vector<4xf32>, %am: vector<4xf32>, %an: vector<4xf32>,
+ %ao: vector<4xf32>, %ap: vector<4xf32>, %aq: vector<4xf32>, %ar: vector<4xf32>,
+ %as: vector<4xf32>, %at: vector<4xf32>, %au: vector<4xf32>, %av: vector<4xf32>,
+ %aw: vector<4xf32>, %ax: vector<4xf32>, %ay: vector<4xf32>, %az: vector<4xf32>,
+ %ba: vector<4xf32>, %bb: vector<4xf32>, %bc: vector<4xf32>, %bd: vector<4xf32>,
+ %be: vector<4xf32>, %bf: vector<4xf32>, %bg: vector<4xf32>, %bh: vector<4xf32>,
+ %bi: vector<4xf32>, %bj: vector<4xf32>, %bk: vector<4xf32>, %bl: vector<4xf32>) -> vector<256xf32> {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ %1:4 = vector.to_elements %b : vector<4xf32>
+ %2:4 = vector.to_elements %c : vector<4xf32>
+ %3:4 = vector.to_elements %d : vector<4xf32>
+ %4:4 = vector.to_elements %e : vector<4xf32>
+ %5:4 = vector.to_elements %f : vector<4xf32>
+ %6:4 = vector.to_elements %g : vector<4xf32>
+ %7:4 = vector.to_elements %h : vector<4xf32>
+ %8:4 = vector.to_elements %i : vector<4xf32>
+ %9:4 = vector.to_elements %j : vector<4xf32>
+ %10:4 = vector.to_elements %k : vector<4xf32>
+ %11:4 = vector.to_elements %l : vector<4xf32>
+ %12:4 = vector.to_elements %m : vector<4xf32>
+ %13:4 = vector.to_elements %n : vector<4xf32>
+ %14:4 = vector.to_elements %o : vector<4xf32>
+ %15:4 = vector.to_elements %p : vector<4xf32>
+ %16:4 = vector.to_elements %q : vector<4xf32>
+ %17:4 = vector.to_elements %r : vector<4xf32>
+ %18:4 = vector.to_elements %s : vector<4xf32>
+ %19:4 = vector.to_elements %t : vector<4xf32>
+ %20:4 = vector.to_elements %u : vector<4xf32>
+ %21:4 = vector.to_elements %v : vector<4xf32>
+ %22:4 = vector.to_elements %w : vector<4xf32>
+ %23:4 = vector.to_elements %x : vector<4xf32>
+ %24:4 = vector.to_elements %y : vector<4xf32>
+ %25:4 = vector.to_elements %z : vector<4xf32>
+ %26:4 = vector.to_elements %aa : vector<4xf32>
+ %27:4 = vector.to_elements %ab : vector<4xf32>
+ %28:4 = vector.to_elements %ac : vector<4xf32>
+ %29:4 = vector.to_elements %ad : vector<4xf32>
+ %30:4 = vector.to_elements %ae : vector<4xf32>
+ %31:4 = vector.to_elements %af : vector<4xf32>
+ %32:4 = vector.to_elements %ag : vector<4xf32>
+ %33:4 = vector.to_elements %ah : vector<4xf32>
+ %34:4 = vector.to_elements %ai : vector<4xf32>
+ %35:4 = vector.to_elements %aj : vector<4xf32>
+ %36:4 = vector.to_elements %ak : vector<4xf32>
+ %37:4 = vector.to_elements %al : vector<4xf32>
+ %38:4 = vector.to_elements %am : vector<4xf32>
+ %39:4 = vector.to_elements %an : vector<4xf32>
+ %40:4 = vector.to_elements %ao : vector<4xf32>
+ %41:4 = vector.to_elements %ap : vector<4xf32>
+ %42:4 = vector.to_elements %aq : vector<4xf32>
+ %43:4 = vector.to_elements %ar : vector<4xf32>
+ %44:4 = vector.to_elements %as : vector<4xf32>
+ %45:4 = vector.to_elements %at : vector<4xf32>
+ %46:4 = vector.to_elements %au : vector<4xf32>
+ %47:4 = vector.to_elements %av : vector<4xf32>
+ %48:4 = vector.to_elements %aw : vector<4xf32>
+ %49:4 = vector.to_elements %ax : vector<4xf32>
+ %50:4 = vector.to_elements %ay : vector<4xf32>
+ %51:4 = vector.to_elements %az : vector<4xf32>
+ %52:4 = vector.to_elements %ba : vector<4xf32>
+ %53:4 = vector.to_elements %bb : vector<4xf32>
+ %54:4 = vector.to_elements %bc : vector<4xf32>
+ %55:4 = vector.to_elements %bd : vector<4xf32>
+ %56:4 = vector.to_elements %be : vector<4xf32>
+ %57:4 = vector.to_elements %bf : vector<4xf32>
+ %58:4 = vector.to_elements %bg : vector<4xf32>
+ %59:4 = vector.to_elements %bh : vector<4xf32>
+ %60:4 = vector.to_elements %bi : vector<4xf32>
+ %61:4 = vector.to_elements %bj : vector<4xf32>
+ %62:4 = vector.to_elements %bk : vector<4xf32>
+ %63:4 = vector.to_elements %bl : vector<4xf32>
+ %64 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3, %3#0, %3#1, %3#2, %3#3, %4#0, %4#1, %4#2, %4#3,
+ %5#0, %5#1, %5#2, %5#3, %6#0, %6#1, %6#2, %6#3, %7#0, %7#1, %7#2, %7#3, %8#0, %8#1, %8#2, %8#3, %9#0, %9#1, %9#2, %9#3,
+ %10#0, %10#1, %10#2, %10#3, %11#0, %11#1, %11#2, %11#3, %12#0, %12#1, %12#2, %12#3, %13#0, %13#1, %13#2, %13#3, %14#0, %14#1, %14#2, %14#3,
+ %15#0, %15#1, %15#2, %15#3, %16#0, %16#1, %16#2, %16#3, %17#0, %17#1, %17#2, %17#3, %18#0, %18#1, %18#2, %18#3, %19#0, %19#1, %19#2, %19#3,
+ %20#0, %20#1, %20#2, %20#3, %21#0, %21#1, %21#2, %21#3, %22#0, %22#1, %22#2, %22#3, %23#0, %23#1, %23#2, %23#3, %24#0, %24#1, %24#2, %24#3,
+ %25#0, %25#1, %25#2, %25#3, %26#0, %26#1, %26#2, %26#3, %27#0, %27#1, %27#2, %27#3, %28#0, %28#1, %28#2, %28#3, %29#0, %29#1, %29#2, %29#3,
+ %30#0, %30#1, %30#2, %30#3, %31#0, %31#1, %31#2, %31#3, %32#0, %32#1, %32#2, %32#3, %33#0, %33#1, %33#2, %33#3, %34#0, %34#1, %34#2, %34#3,
+ %35#0, %35#1, %35#2, %35#3, %36#0, %36#1, %36#2, %36#3, %37#0, %37#1, %37#2, %37#3, %38#0, %38#1, %38#2, %38#3, %39#0, %39#1, %39#2, %39#3,
+ %40#0, %40#1, %40#2, %40#3, %41#0, %41#1, %41#2, %41#3, %42#0, %42#1, %42#2, %42#3, %43#0, %43#1, %43#2, %43#3, %44#0, %44#1, %44#2, %44#3,
+ %45#0, %45#1, %45#2, %45#3, %46#0, %46#1, %46#2, %46#3, %47#0, %47#1, %47#2, %47#3, %48#0, %48#1, %48#2, %48#3, %49#0, %49#1, %49#2, %49#3,
+ %50#0, %50#1, %50#2, %50#3, %51#0, %51#1, %51#2, %51#3, %52#0, %52#1, %52#2, %52#3, %53#0, %53#1, %53#2, %53#3, %54#0, %54#1, %54#2, %54#3,
+ %55#0, %55#1, %55#2, %55#3, %56#0, %56#1, %56#2, %56#3, %57#0, %57#1, %57#2, %57#3, %58#0, %58#1, %58#2, %58#3, %59#0, %59#1, %59#2, %59#3,
+ %60#0, %60#1, %60#2, %60#3, %61#0, %61#1, %61#2, %61#3, %62#0, %62#1, %62#2, %62#3, %63#0, %63#1, %63#2, %63#3 : vector<256xf32>
+ return %64 : vector<256xf32>
+}
+
+// CHECK-LABEL: func.func @to_from_elements_shuffle_tree_concat_64x4_256(
+// CHECK-SAME: %[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>, %[[D:.+]]: vector<4xf32>, %[[E:.+]]: vector<4xf32>, %[[F:.+]]: vector<4xf32>, %[[G:.+]]: vector<4xf32>, %[[H:.+]]: vector<4xf32>, %[[I:.+]]: vector<4xf32>, %[[J:.+]]: vector<4xf32>, %[[K:.+]]: vector<4xf32>, %[[L:.+]]: vector<4xf32>, %[[M:.+]]: vector<4xf32>, %[[N:.+]]: vector<4xf32>, %[[O:.+]]: vector<4xf32>, %[[P:.+]]: vector<4xf32>, %[[Q:.+]]: vector<4xf32>, %[[R:.+]]: vector<4xf32>, %[[S:.+]]: vector<4xf32>, %[[T:.+]]: vector<4xf32>, %[[U:.+]]: vector<4xf32>, %[[V:.+]]: vector<4xf32>, %[[W:.+]]: vector<4xf32>, %[[X:.+]]: vector<4xf32>, %[[Y:.+]]: vector<4xf32>, %[[Z:.+]]: vector<4xf32>, %[[AA:.+]]: vector<4xf32>, %[[AB:.+]]: vector<4xf32>, %[[AC:.+]]: vector<4xf32>, %[[AD:.+]]: vector<4xf32>, %[[AE:.+]]: vector<4xf32>, %[[AF:.+]]: vector<4xf32>, %[[AG:.+]]: vector<4xf32>, %[[AH:.+]]: vector<4xf32>, %[[AI:.+]]: vector<4xf32>, %[[AJ:.+]]: vector<4xf32>, %[[AK:.+]]: vector<4xf32>, %[[AL:.+]]: vector<4xf32>, %[[AM:.+]]: vector<4xf32>, %[[AN:.+]]: vector<4xf32>, %[[AO:.+]]: vector<4xf32>, %[[AP:.+]]: vector<4xf32>, %[[AQ:.+]]: vector<4xf32>, %[[AR:.+]]: vector<4xf32>, %[[AS:.+]]: vector<4xf32>, %[[AT:.+]]: vector<4xf32>, %[[AU:.+]]: vector<4xf32>, %[[AV:.+]]: vector<4xf32>, %[[AW:.+]]: vector<4xf32>, %[[AX:.+]]: vector<4xf32>, %[[AY:.+]]: vector<4xf32>, %[[AZ:.+]]: vector<4xf32>, %[[BA:.+]]: vector<4xf32>, %[[BB:.+]]: vector<4xf32>, %[[BC:.+]]: vector<4xf32>, %[[BD:.+]]: vector<4xf32>, %[[BE:.+]]: vector<4xf32>, %[[BF:.+]]: vector<4xf32>, %[[BG:.+]]: vector<4xf32>, %[[BH:.+]]: vector<4xf32>, %[[BI:.+]]: vector<4xf32>, %[[BJ:.+]]: vector<4xf32>, %[[BK:.+]]: vector<4xf32>, %[[BL:.+]]: vector<4xf32>)
+// CHECK: %[[L0SH0:.+]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH1:.+]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH2:.+]] = vector.shuffle %[[E]], %[[F]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH3:.+]] = vector.shuffle %[[G]], %[[H]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH4:.+]] = vector.shuffle %[[I]], %[[J]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH5:.+]] = vector.shuffle %[[K]], %[[L]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH6:.+]] = vector.shuffle %[[M]], %[[N]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH7:.+]] = vector.shuffle %[[O]], %[[P]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH8:.+]] = vector.shuffle %[[Q]], %[[R]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH9:.+]] = vector.shuffle %[[S]], %[[T]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH10:.+]] = vector.shuffle %[[U]], %[[V]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH11:.+]] = vector.shuffle %[[W]], %[[X]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH12:.+]] = vector.shuffle %[[Y]], %[[Z]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH13:.+]] = vector.shuffle %[[AA]], %[[AB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH14:.+]] = vector.shuffle %[[AC]], %[[AD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH15:.+]] = vector.shuffle %[[AE]], %[[AF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH16:.+]] = vector.shuffle %[[AG]], %[[AH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH17:.+]] = vector.shuffle %[[AI]], %[[AJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH18:.+]] = vector.shuffle %[[AK]], %[[AL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH19:.+]] = vector.shuffle %[[AM]], %[[AN]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH20:.+]] = vector.shuffle %[[AO]], %[[AP]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH21:.+]] = vector.shuffle %[[AQ]], %[[AR]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH22:.+]] = vector.shuffle %[[AS]], %[[AT]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH23:.+]] = vector.shuffle %[[AU]], %[[AV]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH24:.+]] = vector.shuffle %[[AW]], %[[AX]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH25:.+]] = vector.shuffle %[[AY]], %[[AZ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH26:.+]] = vector.shuffle %[[BA]], %[[BB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH27:.+]] = vector.shuffle %[[BC]], %[[BD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH28:.+]] = vector.shuffle %[[BE]], %[[BF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH29:.+]] = vector.shuffle %[[BG]], %[[BH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH30:.+]] = vector.shuffle %[[BI]], %[[BJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH31:.+]] = vector.shuffle %[[BK]], %[[BL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L1SH0:.+]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH1:.+]] = vector.shuffle %[[L0SH2]], %[[L0SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH2:.+]] = vector.shuffle %[[L0SH4]], %[[L0SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH3:.+]] = vector.shuffle %[[L0SH6]], %[[L0SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH4:.+]] = vector.shuffle %[[L0SH8]], %[[L0SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH5:.+]] = vector.shuffle %[[L0SH10]], %[[L0SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH6:.+]] = vector.shuffle %[[L0SH12]], %[[L0SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH7:.+]] = vector.shuffle %[[L0SH14]], %[[L0SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH8:.+]] = vector.shuffle %[[L0SH16]], %[[L0SH17]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH9:.+]] = vector.shuffle %[[L0SH18]], %[[L0SH19]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH10:.+]] = vector.shuffle %[[L0SH20]], %[[L0SH21]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH11:.+]] = vector.shuffle %[[L0SH22]], %[[L0SH23]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH12:.+]] = vector.shuffle %[[L0SH24]], %[[L0SH25]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH13:.+]] = vector.shuffle %[[L0SH26]], %[[L0SH27]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH14:.+]] = vector.shuffle %[[L0SH28]], %[[L0SH29]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L1SH15:.+]] = vector.shuffle %[[L0SH30]], %[[L0SH31]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[L2SH0:.+]] = vector.shuffle %[[L1SH0]], %[[L1SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L2SH1:.+]] = vector.shuffle %[[L1SH2]], %[[L1SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L2SH2:.+]] = vector.shuffle %[[L1SH4]], %[[L1SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L2SH3:.+]] = vector.shuffle %[[L1SH6]], %[[L1SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L2SH4:.+]] = vector.shuffle %[[L1SH8]], %[[L1SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L2SH5:.+]] = vector.shuffle %[[L1SH10]], %[[L1SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L2SH6:.+]] = vector.shuffle %[[L1SH12]], %[[L1SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L2SH7:.+]] = vector.shuffle %[[L1SH14]], %[[L1SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[L3SH0:.+]] = vector.shuffle %[[L2SH0]], %[[L2SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+// CHECK: %[[L3SH1:.+]] = vector.shuffle %[[L2SH2]], %[[L2SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+// CHECK: %[[L3SH2:.+]] = vector.shuffle %[[L2SH4]], %[[L2SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+// CHECK: %[[L3SH3:.+]] = vector.shuffle %[[L2SH6]], %[[L2SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+// CHECK: %[[L4SH0:.+]] = vector.shuffle %[[L3SH0]], %[[L3SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32>
+// CHECK: %[[L4SH1:.+]] = vector.shuffle %[[L3SH2]], %[[L3SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32>
+// CHECK: %[[L5SH0:.+]] = vector.shuffle %[[L4SH0]], %[[L4SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] : vector<128xf32>, vector<128xf32>
+// CHECK: return %[[L5SH0]] : vector<256xf32>
+
+// -----
+// Four inputs in arbitrary order; inputs pair by first use (%d+%a, %c+%b).
+func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
+ %b: vector<4xf32>,
+ %c: vector<4xf32>,
+ %d: vector<4xf32>) -> vector<16xf32> {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ %1:4 = vector.to_elements %b : vector<4xf32>
+ %2:4 = vector.to_elements %c : vector<4xf32>
+ %3:4 = vector.to_elements %d : vector<4xf32>
+ %4 = vector.from_elements %3#3, %0#0, %2#2, %1#1, %3#0, %2#1, %0#3, %1#2, %0#1, %3#2, %1#0, %2#3, %1#3, %0#2, %3#1, %2#0 : vector<16xf32>
+ return %4 : vector<16xf32>
+}
+
+// TODO: Implement mask compression to reduce the number of intermediate poison values.
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>, %[[D:.*]]: vector<4xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[D]], %[[A]] [3, 4, -1, -1, 0, -1, 7, -1, 5, 2, -1, -1, -1, 6, 1] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[B]] [2, 5, -1, 1, -1, 6, -1, -1, 4, 3, 7, -1, -1, 0, -1] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 15, 16, 4, 18, 6, 20, 8, 9, 23, 24, 25, 13, 14, 28] : vector<15xf32>, vector<15xf32>
+// CHECK: return %[[L1SH0]] : vector<16xf32>
+
+// -----
+// Three inputs interleaved arbitrarily; the unpaired %c is shuffled with itself.
+func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
+ %b: vector<4xf32>,
+ %c: vector<4xf32>) -> vector<12xf32> {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ %1:4 = vector.to_elements %b : vector<4xf32>
+ %2:4 = vector.to_elements %c : vector<4xf32>
+ %3 = vector.from_elements %0#2, %1#1, %2#0, %0#1, %1#0, %2#2, %0#0, %1#3, %2#3, %0#3, %1#2, %2#1 : vector<12xf32>
+ return %3 : vector<12xf32>
+}
+
+// TODO: Implement mask compression to reduce the number of intermediate poison values.
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [2, 5, -1, 1, 4, -1, 0, 7, -1, 3, 6] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, -1, -1, 2, -1, -1, 3, -1, -1, 1, -1] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 11, 3, 4, 14, 6, 7, 17, 9, 10, 20] : vector<11xf32>, vector<11xf32>
+// CHECK: return %[[L1SH0]] : vector<12xf32>
+
+// -----
+// Repeated and unused elements drawn from three 5-element inputs.
+func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
+ %b: vector<5xf32>,
+ %c: vector<5xf32>) -> vector<9xf32> {
+ %0:5 = vector.to_elements %a : vector<5xf32>
+ %1:5 = vector.to_elements %b : vector<5xf32>
+ %2:5 = vector.to_elements %c : vector<5xf32>
+ %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+ return %3 : vector<9xf32>
+}
+
+// TODO: Implement mask compression to reduce the number of intermediate poison values.
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(
+// CHECK-SAME: %[[A:.*]]: vector<5xf32>, %[[B:.*]]: vector<5xf32>, %[[C:.*]]: vector<5xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] : vector<5xf32>, vector<5xf32>
+// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] : vector<5xf32>, vector<5xf32>
+// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 8, 9, 4, 5, 6, 7, 14] : vector<8xf32>, vector<8xf32>
+// CHECK: return %[[L1SH0]] : vector<9xf32>
+
+// -----
+// Every input element repeated four times: the masks repeat indices.
+func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
+ %b: vector<2xf32>,
+ %c: vector<2xf32>,
+ %d: vector<2xf32>) -> vector<32xf32> {
+ %0:2 = vector.to_elements %a : vector<2xf32>
+ %1:2 = vector.to_elements %b : vector<2xf32>
+ %2:2 = vector.to_elements %c : vector<2xf32>
+ %3:2 = vector.to_elements %d : vector<2xf32>
+ %4 = vector.from_elements %0#0, %0#0, %0#0, %0#0, %0#1, %0#1, %0#1, %0#1,
+ %1#0, %1#0, %1#0, %1#0, %1#1, %1#1, %1#1, %1#1,
+ %2#0, %2#0, %2#0, %2#0, %2#1, %2#1, %2#1, %2#1,
+ %3#0, %3#0, %3#0, %3#0, %3#1, %3#1, %3#1, %3#1 : vector<32xf32>
+ return %4 : vector<32xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(
+// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: return %[[L1SH0]] : vector<32xf32>
>From 810c85a3d3aaabf93adb0c40c6aef0d0aa52be4b Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 1 Jul 2025 05:05:30 +0000
Subject: [PATCH 2/2] Feedback
---
...LowerVectorToFromElementsToShuffleTree.cpp | 161 ++++++++++--------
...m-elements-to-shuffle-tree-transforms.mlir | 103 +++++++++--
2 files changed, 181 insertions(+), 83 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 53728d6dbe2a3..504103529cdcb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -124,7 +124,6 @@ constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
///
/// TODO: Implement mask compression to reduce the number of intermediate poison
/// values.
-///
class VectorShuffleTreeBuilder {
public:
VectorShuffleTreeBuilder() = delete;
@@ -142,8 +141,8 @@ class VectorShuffleTreeBuilder {
private:
// IR input information.
- FromElementsOp fromElementsOp;
- SmallVector<ToElementsOp> toElementsDefs;
+ FromElementsOp fromElemsOp;
+ SmallVector<ToElementsOp> toElemsDefs;
// Shuffle tree configuration.
unsigned numLevels;
@@ -162,16 +161,19 @@ class VectorShuffleTreeBuilder {
VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
- : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
-
- assert(fromElementsOp && "from_elements op is required");
- assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+ : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) {
+ assert(fromElemsOp && "from_elements op is required");
+ assert(!toElemsDefs.empty() && "At least one to_elements op is required");
+}
- // Duplicate the last vector if the number of `vector.to_elements` is odd to
- // simplify the shuffle tree algorithm.
- if (toElementsDefs.size() % 2 != 0) {
- toElementsDefs.push_back(toElementsDefs.back());
- }
+/// Duplicate the last element (operation, value or interval) if the total
+/// number of elements is odd. Vectors are shuffled in pairs, so duplicating the
+/// last one simplifies the shuffle tree algorithm: the last shuffle of a level
+/// then takes the same (duplicated) vector as both of its inputs.
+template <typename T>
+static void duplicateLastIfOdd(SmallVectorImpl<T> &values) {
+ if (values.size() % 2 != 0)
+ values.push_back(values.back());
}
// ===--------------------------------------------------------------------===//
@@ -207,20 +209,20 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
// Map `vector.to_elements` ops to their ordinal position in the
// `vector.from_elements` operand list. Make sure duplicated
// `vector.to_elements` ops are mapped to the its first occurrence.
- DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
- for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
- toElementsToInputOrdinal.insert({toElementsOp, idx});
+ DenseMap<ToElementsOp, unsigned> toElemsToInputOrdinal;
+ for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs))
+ toElemsToInputOrdinal.insert({toElemsOp, idx});
// Compute intervals for each input vector in the shuffle tree. The first
// level computation is special-cased to keep the implementation simpler.
- SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+ SmallVector<Interval> firstLevelIntervals(toElemsDefs.size(),
{kMaxUnsigned, kMaxUnsigned});
for (const auto &[idx, element] :
- llvm::enumerate(fromElementsOp.getElements())) {
- auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
- unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+ llvm::enumerate(fromElemsOp.getElements())) {
+ auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp());
+ unsigned inputIdx = toElemsToInputOrdinal[toElemsOp];
Interval ¤tInterval = firstLevelIntervals[inputIdx];
// Set lower bound to the first occurrence of the `vector.to_elements`.
@@ -231,19 +233,13 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
currentInterval.second = idx;
}
- // If the number of `vector.to_elements` is odd and the last op was
- // duplicated, the interval for the duplicated op was not computed in the
- // previous step as all the input occurrences were mapped to the original op.
- // We copy the interval of the original op to the interval of the duplicated
- // op manually.
- if (firstLevelIntervals.back().second == kMaxUnsigned)
- firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
-
+ duplicateLastIfOdd(toElemsDefs);
+ duplicateLastIfOdd(firstLevelIntervals);
inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
// Compute intervals for the remaining levels.
unsigned outputNumElements =
- cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+ cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
for (unsigned level = 1; level < numLevels; ++level) {
const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
SmallVector<Interval> currentLevelIntervals(
@@ -265,6 +261,7 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
outputNumElements - 1);
}
+ duplicateLastIfOdd(currentLevelIntervals);
inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
}
}
@@ -311,9 +308,9 @@ void VectorShuffleTreeBuilder::dump() {
++indLv;
llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
++indLv;
- for (const auto &toElementsOp : toElementsDefs)
- llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
- llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+ for (const auto &toElemsOp : toElemsDefs)
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp << "\n\n";
--indLv;
llvm::dbgs() << llvm::indent(indLv, kIndScale)
@@ -366,9 +363,7 @@ void VectorShuffleTreeBuilder::dump() {
/// corresponding utility functions.
LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
// Initialize shuffle tree information based on its size.
- assert(toElementsDefs.size() > 1 &&
- "At least two 'vector.to_elements' ops are required");
- numLevels = llvm::Log2_64(toElementsDefs.size());
+ numLevels = std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size()));
vectorSizePerLevel.resize(numLevels, 0);
inputIntervalsPerLevel.reserve(numLevels);
@@ -402,17 +397,18 @@ LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
/// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
/// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
///
-/// TODO: Implement mask compression.
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
static SmallVector<int64_t> computePermutationShuffleMask(
ToElementsOp toElementOp0, const Interval &interval0,
ToElementsOp toElementOp1, const Interval &interval1,
- FromElementsOp fromElementsOp, unsigned outputVectorSize) {
+ FromElementsOp fromElemsOp, unsigned outputVectorSize) {
SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
unsigned inputVectorSize =
toElementOp0.getSource().getType().getNumElements();
for (const auto &[inputIdx, element] :
- llvm::enumerate(fromElementsOp.getElements())) {
+ llvm::enumerate(fromElemsOp.getElements())) {
auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
// Match `vector.from_elements` operands to the two input ops.
if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
@@ -476,8 +472,8 @@ static SmallVector<int64_t> computePermutationShuffleMask(
/// // Level 1, vector length = 9
/// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
///
-/// TODO: Implement mask compression.
-///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
static SmallVector<int64_t> computePropagationShuffleMask(
ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
const Interval &rhsInterval, unsigned outputVectorSize) {
@@ -487,6 +483,7 @@ static SmallVector<int64_t> computePropagationShuffleMask(
assert(inputVectorSize == rhsShuffleMask.size() &&
"Expected both shuffle masks to have the same size");
+ bool hasSameInput = lhsShuffleOp == rhsShuffleOp;
unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
@@ -497,6 +494,9 @@ static SmallVector<int64_t> computePropagationShuffleMask(
if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex)
mask[i] = i;
+ if (hasSameInput)
+ continue;
+
unsigned rhsIdx = i + lhsRhsOffset;
if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
@@ -565,15 +565,19 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
// Initialize work list with the `vector.to_elements` sources.
SmallVector<Value> levelInputs;
- llvm::transform(
- toElementsDefs, std::back_inserter(levelInputs),
- [](ToElementsOp toElementsOp) { return toElementsOp.getSource(); });
+ llvm::transform(toElemsDefs, std::back_inserter(levelInputs),
+ [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); });
+ // TODO: Check that every pair of inputs has the same vector size. Otherwise,
+ // promote the narrower vector to the wider one.
// Build shuffle tree by combining pairs of vectors.
- Location loc = fromElementsOp.getLoc();
+ Location loc = fromElemsOp.getLoc();
unsigned currentLevel = 0;
for (const auto &[levelVectorSize, inputIntervals] :
llvm::zip_equal(vectorSizePerLevel, inputIntervalsPerLevel)) {
+
+ duplicateLastIfOdd(levelInputs);
+
LLVM_DEBUG(llvm::dbgs()
<< llvm::indent(1, kIndScale) << "* Processing level "
<< currentLevel << " (vector size: " << levelVectorSize
@@ -593,8 +597,8 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
SmallVector<int64_t> shuffleMask;
if (currentLevel == 0) {
shuffleMask = computePermutationShuffleMask(
- toElementsDefs[i], lhsInterval, toElementsDefs[i + 1], rhsInterval,
- fromElementsOp, levelVectorSize);
+ toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval,
+ fromElemsOp, levelVectorSize);
} else {
auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.getDefiningOp());
auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.getDefiningOp());
@@ -621,17 +625,17 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
/// returned in order of appearance in the `vector.from_elements`'s operand
/// list.
static LogicalResult
-getToElementsDefiningOps(FromElementsOp fromElementsOp,
- SmallVectorImpl<ToElementsOp> &toElementsDefs) {
- SetVector<ToElementsOp> toElementsDefsSet;
- for (Value element : fromElementsOp.getElements()) {
- auto toElementsOp = element.getDefiningOp<ToElementsOp>();
- if (!toElementsOp)
+getToElementsDefiningOps(FromElementsOp fromElemsOp,
+ SmallVectorImpl<ToElementsOp> &toElemsDefs) {
+ SetVector<ToElementsOp> toElemsDefsSet;
+ for (Value element : fromElemsOp.getElements()) {
+ auto toElemsOp = element.getDefiningOp<ToElementsOp>();
+ if (!toElemsOp)
return failure();
- toElementsDefsSet.insert(toElementsOp);
+ toElemsDefsSet.insert(toElemsOp);
}
- toElementsDefs.assign(toElementsDefsSet.begin(), toElementsDefsSet.end());
+ toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end());
return success();
}
@@ -642,30 +646,53 @@ struct ToFromElementsToShuffleTreeRewrite final
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp,
+ LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
PatternRewriter &rewriter) const override {
- VectorType resultType = fromElementsOp.getType();
- if (resultType.getRank() != 1 || resultType.isScalable())
- return failure();
+ VectorType resultType = fromElemsOp.getType();
+ if (resultType.getRank() != 1)
+ return rewriter.notifyMatchFailure(
+ fromElemsOp, "Multi-dimensional vectors are not supported yet");
+ if (resultType.isScalable())
+ return rewriter.notifyMatchFailure(
+ fromElemsOp,
+ "'vector.from_elements' does not support scalable vectors");
+
+ SmallVector<ToElementsOp> toElemsDefs;
+ if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs)))
+ return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources");
+
+ int64_t numElements =
+ toElemsDefs.front().getSource().getType().getNumElements();
+ for (ToElementsOp toElemsOp : toElemsDefs) {
+ if (toElemsOp.getSource().getType().getNumElements() != numElements)
+ return rewriter.notifyMatchFailure(
+ fromElemsOp, "unsupported sources with different vector sizes");
+ }
- SmallVector<ToElementsOp> toElementsDefs;
- if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs)))
- return failure();
+ if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
+ return !toElemsOp.getSource().getType().hasRank();
+ })) {
+ return rewriter.notifyMatchFailure(fromElemsOp,
+ "0-D vectors are not supported");
+ }
// Avoid generating a shuffle tree for trivial `vector.to_elements` ->
// `vector.from_elements` forwarding cases that do not require shuffling.
- if (toElementsDefs.size() == 1) {
- ToElementsOp toElementsOp0 = toElementsDefs.front();
- if (llvm::equal(fromElementsOp.getElements(), toElementsOp0.getResults()))
- return failure();
+ if (toElemsDefs.size() == 1) {
+ ToElementsOp toElemsOp0 = toElemsDefs.front();
+ if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) {
+ return rewriter.notifyMatchFailure(
+ fromElemsOp, "trivial forwarding case does not require shuffling");
+ }
}
- VectorShuffleTreeBuilder shuffleTreeBuilder(fromElementsOp, toElementsDefs);
+ VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs);
if (failed(shuffleTreeBuilder.computeShuffleTree()))
- return failure();
+ return rewriter.notifyMatchFailure(fromElemsOp,
+ "failed to compute shuffle tree");
Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
- rewriter.replaceOp(fromElementsOp, finalShuffle);
+ rewriter.replaceOp(fromElemsOp, finalShuffle);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
index 3dc579be12f0f..a8d3d5278e893 100644
--- a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
@@ -4,13 +4,27 @@
// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to
// the shuffle index within that level.
-func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
+func.func @trivial_forwarding(%a: vector<8xf32>) -> vector<8xf32> {
+ %0:8 = vector.to_elements %a : vector<8xf32>
+ %1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// No shuffle tree needed for trivial forwarding case.
+
+// CHECK-LABEL: func @trivial_forwarding(
+// CHECK-SAME: %[[A:.*]]: vector<8xf32>
+// CHECK: return %[[A]] : vector<8xf32>
+
+// -----
+
+func.func @single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
%0:8 = vector.to_elements %a : vector<8xf32>
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32>
return %1 : vector<8xf32>
}
-// CHECK-LABEL: func @to_from_elements_single_input_shuffle(
+// CHECK-LABEL: func @single_input_shuffle(
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32>
// CHECK: return %[[L0SH0]]
@@ -32,7 +46,7 @@ func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>,
// -----
-func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
+func.func @shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
%b: vector<8xf32>,
%c: vector<8xf32>,
%d: vector<8xf32>) -> vector<32xf32> {
@@ -47,7 +61,7 @@ func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
return %4 : vector<32xf32>
}
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_4x8_to_32(
+// CHECK-LABEL: func @shuffle_tree_concat_4x8_to_32(
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<8xf32>, %[[D:.*]]: vector<8xf32>
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
@@ -56,7 +70,7 @@ func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
// -----
-func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
+func.func @shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
%b: vector<4xf32>,
%c: vector<4xf32>) -> vector<12xf32> {
%0:4 = vector.to_elements %a : vector<4xf32>
@@ -66,7 +80,7 @@ func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
return %3 : vector<12xf32>
}
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_3x4_to_12(
+// CHECK-LABEL: func @shuffle_tree_concat_3x4_to_12(
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
@@ -75,7 +89,7 @@ func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
// -----
-func.func @to_from_elements_shuffle_tree_concat_64x4_256(
+func.func @shuffle_tree_concat_64x4_256(
%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>,
%e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>,
%i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>,
@@ -172,7 +186,7 @@ func.func @to_from_elements_shuffle_tree_concat_64x4_256(
return %64 : vector<256xf32>
}
-// CHECK-LABEL: func.func @to_from_elements_shuffle_tree_concat_64x4_256(
+// CHECK-LABEL: func.func @shuffle_tree_concat_64x4_256(
// CHECK-SAME: %[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>, %[[D:.+]]: vector<4xf32>, %[[E:.+]]: vector<4xf32>, %[[F:.+]]: vector<4xf32>, %[[G:.+]]: vector<4xf32>, %[[H:.+]]: vector<4xf32>, %[[I:.+]]: vector<4xf32>, %[[J:.+]]: vector<4xf32>, %[[K:.+]]: vector<4xf32>, %[[L:.+]]: vector<4xf32>, %[[M:.+]]: vector<4xf32>, %[[N:.+]]: vector<4xf32>, %[[O:.+]]: vector<4xf32>, %[[P:.+]]: vector<4xf32>, %[[Q:.+]]: vector<4xf32>, %[[R:.+]]: vector<4xf32>, %[[S:.+]]: vector<4xf32>, %[[T:.+]]: vector<4xf32>, %[[U:.+]]: vector<4xf32>, %[[V:.+]]: vector<4xf32>, %[[W:.+]]: vector<4xf32>, %[[X:.+]]: vector<4xf32>, %[[Y:.+]]: vector<4xf32>, %[[Z:.+]]: vector<4xf32>, %[[AA:.+]]: vector<4xf32>, %[[AB:.+]]: vector<4xf32>, %[[AC:.+]]: vector<4xf32>, %[[AD:.+]]: vector<4xf32>, %[[AE:.+]]: vector<4xf32>, %[[AF:.+]]: vector<4xf32>, %[[AG:.+]]: vector<4xf32>, %[[AH:.+]]: vector<4xf32>, %[[AI:.+]]: vector<4xf32>, %[[AJ:.+]]: vector<4xf32>, %[[AK:.+]]: vector<4xf32>, %[[AL:.+]]: vector<4xf32>, %[[AM:.+]]: vector<4xf32>, %[[AN:.+]]: vector<4xf32>, %[[AO:.+]]: vector<4xf32>, %[[AP:.+]]: vector<4xf32>, %[[AQ:.+]]: vector<4xf32>, %[[AR:.+]]: vector<4xf32>, %[[AS:.+]]: vector<4xf32>, %[[AT:.+]]: vector<4xf32>, %[[AU:.+]]: vector<4xf32>, %[[AV:.+]]: vector<4xf32>, %[[AW:.+]]: vector<4xf32>, %[[AX:.+]]: vector<4xf32>, %[[AY:.+]]: vector<4xf32>, %[[AZ:.+]]: vector<4xf32>, %[[BA:.+]]: vector<4xf32>, %[[BB:.+]]: vector<4xf32>, %[[BC:.+]]: vector<4xf32>, %[[BD:.+]]: vector<4xf32>, %[[BE:.+]]: vector<4xf32>, %[[BF:.+]]: vector<4xf32>, %[[BG:.+]]: vector<4xf32>, %[[BH:.+]]: vector<4xf32>, %[[BI:.+]]: vector<4xf32>, %[[BJ:.+]]: vector<4xf32>, %[[BK:.+]]: vector<4xf32>, %[[BL:.+]]: vector<4xf32>)
// CHECK: %[[L0SH0:.+]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
// CHECK: %[[L0SH1:.+]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
@@ -241,7 +255,7 @@ func.func @to_from_elements_shuffle_tree_concat_64x4_256(
// -----
-func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
+func.func @shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
%b: vector<4xf32>,
%c: vector<4xf32>,
%d: vector<4xf32>) -> vector<16xf32> {
@@ -255,7 +269,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
// TODO: Implement mask compression to reduce the number of intermediate poison values.
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(
+// CHECK-LABEL: func @shuffle_tree_arbitrary_4x4_to_16(
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>, %[[D:.*]]: vector<4xf32>
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[D]], %[[A]] [3, 4, -1, -1, 0, -1, 7, -1, 5, 2, -1, -1, -1, 6, 1] : vector<4xf32>, vector<4xf32>
// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[B]] [2, 5, -1, 1, -1, 6, -1, -1, 4, 3, 7, -1, -1, 0, -1] : vector<4xf32>, vector<4xf32>
@@ -264,7 +278,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
// -----
-func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
+func.func @shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
%b: vector<4xf32>,
%c: vector<4xf32>) -> vector<12xf32> {
%0:4 = vector.to_elements %a : vector<4xf32>
@@ -276,7 +290,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
// TODO: Implement mask compression to reduce the number of intermediate poison values.
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(
+// CHECK-LABEL: func @shuffle_tree_arbitrary_3x4_to_12(
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [2, 5, -1, 1, 4, -1, 0, 7, -1, 3, 6] : vector<4xf32>, vector<4xf32>
// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, -1, -1, 2, -1, -1, 3, -1, -1, 1, -1] : vector<4xf32>, vector<4xf32>
@@ -285,7 +299,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
// -----
-func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
+func.func @shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
%b: vector<5xf32>,
%c: vector<5xf32>) -> vector<9xf32> {
%0:5 = vector.to_elements %a : vector<5xf32>
@@ -297,7 +311,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
// TODO: Implement mask compression to reduce the number of intermediate poison values.
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(
+// CHECK-LABEL: func @shuffle_tree_arbitrary_3x5_to_9(
// CHECK-SAME: %[[A:.*]]: vector<5xf32>, %[[B:.*]]: vector<5xf32>, %[[C:.*]]: vector<5xf32>
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] : vector<5xf32>, vector<5xf32>
// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] : vector<5xf32>, vector<5xf32>
@@ -306,7 +320,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
// -----
-func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
+func.func @shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
%b: vector<2xf32>,
%c: vector<2xf32>,
%d: vector<2xf32>) -> vector<32xf32> {
@@ -321,9 +335,66 @@ func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
return %4 : vector<32xf32>
}
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(
+// CHECK-LABEL: func @shuffle_tree_broadcast_4x2_to_32(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32>
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: return %[[L1SH0]] : vector<32xf32>
+
+// -----
+
+
+func.func @shuffle_tree_arbitrary_mixed_sizes(
+ %a : vector<2xf32>,
+ %b : vector<1xf32>,
+ %c : vector<3xf32>,
+ %d : vector<1xf32>,
+ %e : vector<5xf32>) -> vector<6xf32> {
+ %0:2 = vector.to_elements %a : vector<2xf32>
+ %1 = vector.to_elements %b : vector<1xf32>
+ %2:3 = vector.to_elements %c : vector<3xf32>
+ %3 = vector.to_elements %d : vector<1xf32>
+ %4:5 = vector.to_elements %e : vector<5xf32>
+ %5 = vector.from_elements %0#0, %2#0, %3, %4#0, %1, %4#3 : vector<6xf32>
+ return %5 : vector<6xf32>
+}
+
+// TODO: Support mixed vector sizes.
+
+// CHECK-LABEL: func @shuffle_tree_arbitrary_mixed_sizes(
+// CHECK-COUNT-5: vector.to_elements
+// CHECK: vector.from_elements
+
+// -----
+
+func.func @shuffle_tree_odd_intermediate_vectors(
+ %a : vector<2xf32>,
+ %b : vector<2xf32>,
+ %c : vector<2xf32>,
+ %d : vector<2xf32>,
+ %e : vector<2xf32>,
+ %f : vector<2xf32>) -> vector<6xf32> {
+ %0:2 = vector.to_elements %a : vector<2xf32>
+ %1:2 = vector.to_elements %b : vector<2xf32>
+ %2:2 = vector.to_elements %c : vector<2xf32>
+ %3:2 = vector.to_elements %d : vector<2xf32>
+ %4:2 = vector.to_elements %e : vector<2xf32>
+ %5:2 = vector.to_elements %f : vector<2xf32>
+ %6 = vector.from_elements %0#0, %1#1, %2#0, %3#1, %4#0, %5#1 : vector<6xf32>
+ return %6 : vector<6xf32>
+}
+
+// CHECK-LABEL: func @shuffle_tree_odd_intermediate_vectors(
+// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32>, %[[E:.*]]: vector<2xf32>, %[[F:.*]]: vector<2xf32>
+// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 3] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 3] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[L0SH2:.*]] = vector.shuffle %[[E]], %[[F]] [0, 3] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[L1SH1:.*]] = vector.shuffle %[[L0SH2]], %[[L0SH2]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[L2SH0:.*]] = vector.shuffle %[[L1SH0]], %[[L1SH1]] [0, 1, 2, 3, 4, 5] : vector<4xf32>, vector<4xf32>
+// CHECK: return %[[L2SH0]] : vector<6xf32>
+
+
+
+
More information about the Mlir-commits
mailing list