[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` tree transformation (PR #145740)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Jul 8 08:55:50 PDT 2025
================
@@ -0,0 +1,747 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir::vector {
+
+// Expand the generated pass definition selected by the GEN_PASS_DEF_* macro.
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace mlir::vector
+
+// Debug-output category for LLVM_DEBUG in this file.
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Indentation unit passed to `llvm::indent` when formatting debug output.
+constexpr unsigned kIndScale = 2;
+
+/// A closed interval [first, second] of positions in the
+/// `vector.from_elements` operand list (e.g., [0, 7] covers 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+/// Sentinel marking an interval bound that has not been assigned yet.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+///   1. Vectors are shuffled in pairs by order of appearance in
+///      the `vector.from_elements` operand list.
+///   2. Each vector at each level is used only once.
+///   3. The number of levels in the tree is:
+///        1 (input vectors) + max(1, ceil(log2(# `vector.to_elements` ops))).
+///   4. Vectors at each level of the tree, except for the input level, have
+///      the same vector length. Input vectors may have different lengths.
+///   5. Vector positions that do not need to be shuffled are represented with
+///      poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+///   %0:4 = vector.to_elements %a : vector<4xf32>
+///   %1:4 = vector.to_elements %b : vector<4xf32>
+///   %2:4 = vector.to_elements %c : vector<4xf32>
+///   %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+///                             %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+///                             : vector<12xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+///               : vector<4xf32>, vector<4xf32>
+///   %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+///               : vector<4xf32>, vector<4xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+///                                                  6, 7, 8, 9, 10, 11]
+///             : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+///   * The shuffle tree has three levels:
+///       - Level 0 = (%a, %b, %c, %c)
+///       - Level 1 = (%shuffle0, %shuffle1)
+///       - Level 2 = (%result)
+///   * `%a` and `%b` are shuffled first because they appear first in the
+///     `vector.from_elements` operand list (`%0#0` and `%1#0`).
+///   * `%c` is shuffled with itself because the number of input vectors
+///     (`vector.to_elements` ops) is odd.
+///   * The vector lengths of level 1 and level 2 are 8 and 12, respectively.
+///   * `%shuffle1` uses poison values to match the vector length of its
+///     tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+///               : vector<5xf32>, vector<5xf32>
+///   %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+///               : vector<5xf32>, vector<5xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///             : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+///   * `%c` and `%b` are shuffled first because they appear first in the
+///     `vector.from_elements` operand list (`%2#2` and `%1#1`).
+///   * `%a` is shuffled with itself because the number of input vectors
+///     (`vector.to_elements` ops) is odd.
+///   * The vector lengths of level 1 and level 2 are 8 and 9, respectively.
+///   * `%shuffle0` uses poison values to mark unused vector positions and
+///     match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+class VectorShuffleTreeBuilder {
+public:
+  VectorShuffleTreeBuilder() = delete;
+  VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+                           ArrayRef<ToElementsOp> toElemDefs);
+
+  /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+  /// and compute the shuffle tree configuration. This method does not generate
+  /// any IR.
+  LogicalResult computeShuffleTree();
+
+  /// Materialize the shuffle tree configuration computed by
+  /// `computeShuffleTree` in the IR.
+  Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+  // IR input information.
+  FromElementsOp fromElemsOp;
+  /// Input `vector.to_elements` ops (level 0 of the tree). The last one is
+  /// duplicated by `computeShuffleTreeIntervals` when their number is odd.
+  SmallVector<ToElementsOp> toElemsDefs;
+
+  // Shuffle tree configuration.
+  /// Number of tree levels, including the input level. Assigned by
+  /// `computeShuffleTree`; zero until then.
+  unsigned numLevels = 0;
+  /// Vector length of each tree level. Level 0 is recorded as 0 because input
+  /// vectors are allowed to have different lengths.
+  SmallVector<unsigned> vectorSizePerLevel;
+  /// Holds the range of positions each vector in the tree contributes to the
+  /// final output vector.
+  SmallVector<SmallVector<Interval>> intervalsPerLevel;
+
+  // Utility methods to compute the shuffle tree configuration.
+  void computeShuffleTreeIntervals();
+  void computeShuffleTreeVectorSizes();
+
+  /// Print the shuffle tree configuration to `llvm::dbgs()` (under
+  /// `LLVM_DEBUG`).
+  void dump();
+};
+
+/// Store the consumer `vector.from_elements` op and a private copy of its
+/// producer `vector.to_elements` ops; both must be non-empty.
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+    FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+    : fromElemsOp(fromElemOp),
+      toElemsDefs(toElemDefs.begin(), toElemDefs.end()) {
+  assert(fromElemsOp && "from_elements op is required");
+  assert(!toElemsDefs.empty() && "At least one to_elements op is required");
+}
+
+/// Append a copy of the last element of `values` when it holds an odd number
+/// of elements; an even count (including zero) leaves `values` untouched.
+/// Vectors are shuffled in pairs, so padding an odd-sized list this way lets
+/// the last shuffle take the same (duplicated) vector as both of its inputs.
+template <typename T>
+static void duplicateLastIfOdd(SmallVectorImpl<T> &values) {
+  const bool isEven = values.size() % 2 == 0;
+  if (isEven)
+    return;
+  values.push_back(values.back());
+}
+
+// ===---------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===---------------------------------------------------------------------===//
+
+/// Compute the intervals for all the vectors in the shuffle tree. The interval
+/// of a vector is the range of positions that the vector contributes to in the
+/// final output vector.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// The shuffle tree has 3 levels. Level 0 has 4 vectors (%2, %1, %0, %0, the
+/// last one is duplicated to make the number of inputs even) so we compute the
+/// interval for each vector:
+///
+///   * intervalsPerLevel[0][0] = interval(%2) = [0,6]
+///   * intervalsPerLevel[0][1] = interval(%1) = [1,7]
+///   * intervalsPerLevel[0][2] = interval(%0) = [2,8]
+///   * intervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 vectors, resulting from the shuffling of %2 + %1 and %0 + %0,
+/// so we compute the intervals for each vector at level 1 as:
+///   * intervalsPerLevel[1][0] = intervalsPerLevel[0][0] U
+///                               intervalsPerLevel[0][1] = [0,7]
+///   * intervalsPerLevel[1][1] = intervalsPerLevel[0][2] U
+///                               intervalsPerLevel[0][3] = [2,8]
+///
+/// Level 2 is the last level and only contains the output vector so the
+/// interval should be the whole output vector:
+///   * intervalsPerLevel[2][0] = intervalsPerLevel[1][0] U
+///                               intervalsPerLevel[1][1] = [0,8]
+///
+/// Side effect: when the number of `vector.to_elements` inputs is odd, the
+/// last one is duplicated in `toElemsDefs`, keeping level-0 vectors and
+/// level-0 intervals index-aligned.
+///
+/// The union keeps the LHS lower bound, which relies on inputs being ordered
+/// by first appearance in the `vector.from_elements` operand list (so
+/// lhs.first <= rhs.first). The shuffle-mask computations also use the LHS
+/// interval start as the origin of each combined vector.
+void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() {
+  // Map `vector.to_elements` ops to their ordinal position in the
+  // `vector.from_elements` operand list. `insert` never overwrites an
+  // existing entry, so duplicated `vector.to_elements` ops are mapped to
+  // their first occurrence.
+  DenseMap<ToElementsOp, unsigned> toElemsToInputOrdinal;
+  for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs))
+    toElemsToInputOrdinal.insert({toElemsOp, idx});
+
+  // Compute intervals for each vector in the shuffle tree. The first
+  // level computation is special-cased to keep the implementation simpler.
+  SmallVector<Interval> firstLevelIntervals(toElemsDefs.size(),
+                                            {kMaxUnsigned, kMaxUnsigned});
+
+  for (const auto &[idx, element] :
+       llvm::enumerate(fromElemsOp.getElements())) {
+    auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp());
+    // Use `lookup` instead of `operator[]` so an unexpected producer is not
+    // silently inserted and mapped to input 0.
+    assert(toElemsToInputOrdinal.contains(toElemsOp) &&
+           "from_elements operand is not produced by a known to_elements op");
+    unsigned inputIdx = toElemsToInputOrdinal.lookup(toElemsOp);
+    Interval &currentInterval = firstLevelIntervals[inputIdx];
+
+    // Set lower bound to the first occurrence of the `vector.to_elements`.
+    if (currentInterval.first == kMaxUnsigned)
+      currentInterval.first = idx;
+
+    // Set upper bound to the last occurrence of the `vector.to_elements`.
+    currentInterval.second = idx;
+  }
+
+  // Pad both lists identically so they stay index-aligned and even-sized.
+  duplicateLastIfOdd(toElemsDefs);
+  duplicateLastIfOdd(firstLevelIntervals);
+  intervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+  // Compute intervals for the remaining levels.
+  for (unsigned level = 1; level < numLevels; ++level) {
+    bool isLastLevel = level == numLevels - 1;
+    const auto &prevLevelIntervals = intervalsPerLevel[level - 1];
+    SmallVector<Interval> currentLevelIntervals(
+        llvm::divideCeil(prevLevelIntervals.size(), 2),
+        {kMaxUnsigned, kMaxUnsigned});
+
+    size_t numCurrentIntervals = currentLevelIntervals.size();
+    for (size_t inputIdx = 0; inputIdx < numCurrentIntervals; ++inputIdx) {
+      auto &interval = currentLevelIntervals[inputIdx];
+      const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+      const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+      assert(prevLhsInterval.first <= prevRhsInterval.first &&
+             "expected inputs ordered by first appearance");
+
+      // The interval of a vector at the current level is the union of the
+      // intervals of the two vectors from the previous level being shuffled at
+      // this level.
+      interval.first = prevLhsInterval.first;
+      interval.second =
+          std::max(prevLhsInterval.second, prevRhsInterval.second);
+    }
+
+    // Duplicate the last interval if the number of intervals is odd, except for
+    // the last level as it only contains the output vector, which doesn't have
+    // to be shuffled.
+    if (!isLastLevel)
+      duplicateLastIfOdd(currentLevelIntervals);
+
+    intervalsPerLevel.push_back(std::move(currentLevelIntervals));
+  }
+}
+
+/// Compute the uniform vector size for each level of the shuffle tree, given
+/// the intervals of the vectors at each level. The vector size of a level is
+/// the size of the widest interval at that level.
+///
+/// Requires `computeShuffleTreeIntervals` to have run and `vectorSizePerLevel`
+/// to hold `numLevels` entries (both arranged by `computeShuffleTree`).
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   Intervals:
+///     * Level 0: [0,6], [1,7], [2,8], [2,8]
+///     * Level 1: [0,7], [2,8]
+///     * Level 2: [0,8]
+///
+///   Vector sizes:
+///     * Level 0: Arbitrary sizes from input vectors.
+///     * Level 1: max(size_of([0,7]) = 8, size_of([2,8]) = 7) = 8
+///     * Level 2: max(size_of([0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() {
+  // Compute vector size for each level. There are two direct cases:
+  //   * First level: the vector size depends on the actual size of the input
+  //     vectors and it's allowed to be non-uniform. We set it to 0.
+  //   * Last level: the vector size is the output vector size so it doesn't
+  //     have to be computed using intervals.
+  vectorSizePerLevel.front() = 0;
+  vectorSizePerLevel.back() =
+      cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
+
+  for (unsigned level = 1; level < numLevels - 1; ++level) {
+    const auto &currentLevelIntervals = intervalsPerLevel[level];
+    unsigned currentVectorSize = 1;
+    for (const Interval &interval : currentLevelIntervals) {
+      // Intervals are closed, hence the +1.
+      unsigned intervalSize = interval.second - interval.first + 1;
+      currentVectorSize = std::max(currentVectorSize, intervalSize);
+    }
+    assert(currentVectorSize > 0 && "vector size must be positive");
+    vectorSizePerLevel[level] = currentVectorSize;
+  }
+}
+
+/// Print the inputs, level count, per-level vector sizes and per-level
+/// intervals to `llvm::dbgs()`. Output is only produced through `LLVM_DEBUG`.
+void VectorShuffleTreeBuilder::dump() {
+  LLVM_DEBUG({
+    llvm::raw_ostream &os = llvm::dbgs();
+    // Emit `depth` indentation units and return the stream for chaining.
+    auto indented = [&](unsigned depth) -> llvm::raw_ostream & {
+      return os << llvm::indent(depth, kIndScale);
+    };
+
+    os << "VectorShuffleTreeBuilder Configuration:\n";
+    indented(1) << "* Inputs:\n";
+    for (ToElementsOp toElemsOp : toElemsDefs)
+      indented(2) << toElemsOp << "\n";
+    indented(2) << fromElemsOp << "\n\n";
+
+    indented(1) << "* Total levels: " << numLevels << "\n";
+    indented(1) << "* Vector sizes per level: ";
+    llvm::interleaveComma(vectorSizePerLevel, os);
+    os << "\n";
+
+    indented(1) << "* Input intervals per level:\n";
+    for (const auto &[level, intervals] : llvm::enumerate(intervalsPerLevel)) {
+      indented(2) << "* Level " << level << ": ";
+      llvm::interleaveComma(intervals, os, [&](const Interval &interval) {
+        os << "[" << interval.first << "," << interval.second << "]";
+      });
+      os << "\n";
+    }
+  });
+}
+
+/// Analyze the `vector.to_elements` + `vector.from_elements` input sequence
+/// and compute a balanced binary shuffle tree that combines pairs of vectors
+/// at each level. No IR is created here.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// produces the tree:
+///
+///         %2    %1                   %0    %0
+///           \  /                       \  /
+///   %2_1 = vector.shuffle      %0_0 = vector.shuffle
+///                   \               /
+///                  %2_1_0_0 = vector.shuffle
+///
+/// The tree is represented by the interval of every vector at every level
+/// (%2, %1, %0, %0, %2_1, %0_0 and %2_1_0_0) plus the vector length of each
+/// level; see `computeShuffleTreeIntervals` and
+/// `computeShuffleTreeVectorSizes` for how those are derived.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+  // One level holds the input `vector.to_elements` ops; the rest are levels
+  // of pairwise shuffles. A single input vector still needs one shuffle level
+  // (log2(1) == 0), hence the lower bound of 1.
+  unsigned numShuffleLevels =
+      std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size()));
+  numLevels = 1u + numShuffleLevels;
+  intervalsPerLevel.reserve(numLevels);
+  vectorSizePerLevel.resize(numLevels, 0);
+
+  // Vector sizes are derived from the intervals, so intervals come first.
+  computeShuffleTreeIntervals();
+  computeShuffleTreeVectorSizes();
+  dump();
+
+  return success();
+}
+
+// ===---------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===---------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the vector elements to their
+/// final position in the output vector, relative to the intermediate output
+/// vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Parameters:
+///   * `toElementOp0`/`interval0`: LHS input and its interval.
+///   * `toElementOp1`/`interval1`: RHS input and its interval (may be the same
+///     op as the LHS when the last input was duplicated).
+///   * `fromElemsOp`: the consumer whose operand order defines positions.
+///   * `outputVectorSize`: length of the produced mask.
+///
+/// The combined output vector starts at `interval0.first`, so mask positions
+/// for elements of either input are computed relative to that position.
+/// Indices into the RHS operand are offset by the LHS vector length.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   =>
+///
+///   // Level 1, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+static SmallVector<int64_t> computePermutationShuffleMask(
+    ToElementsOp toElementOp0, const Interval &interval0,
+    ToElementsOp toElementOp1, const Interval &interval1,
+    FromElementsOp fromElemsOp, unsigned outputVectorSize) {
+  assert(interval0.first <= interval1.first &&
+         "expected the LHS interval to start no later than the RHS interval");
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+  unsigned inputVectorSize =
+      toElementOp0.getSource().getType().getNumElements();
+
+  for (const auto &[inputIdx, element] :
+       llvm::enumerate(fromElemsOp.getElements())) {
+    auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
+    // Match `vector.from_elements` operands to the two input ops. When both
+    // inputs are the same op, every match is treated as LHS.
+    bool isLhs = currentToElemOp == toElementOp0;
+    if (!isLhs && currentToElemOp != toElementOp1)
+      continue;
+
+    // The mask index is the ordinal position of the operand in the
+    // `vector.from_elements` operand list, made relative to the start of the
+    // combined output interval (`interval0.first`) for either input.
+    unsigned maskIdx = inputIdx - interval0.first;
+    assert(maskIdx < outputVectorSize && "mask index out of bounds");
+
+    // The permutation value for a particular operand is the ordinal position of
+    // the operand in the `vector.to_elements` list of results, shifted past
+    // the LHS vector when it comes from the RHS input.
+    unsigned permVal = cast<OpResult>(element).getResultNumber();
+    if (!isLhs)
+      permVal += inputVectorSize;
+
+    mask[maskIdx] = permVal;
+  }
+
+  LLVM_DEBUG({
+    unsigned indLv = 1;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Combining: " << toElementOp0 << " and " << toElementOp1
+                 << "\n";
+  });
+
+  return mask;
+}
+
+/// Compute the propagation shuffle mask for combining two intermediate shuffle
+/// operations of the tree. The propagation shuffle mask is the mapping of the
+/// intermediate vector elements, which have already been shuffled to their
+/// relative output position using the mask generated by
+/// `computePermutationShuffleMask`, to their next position in the tree.
+///
+/// Non-poison LHS elements keep their position. Non-poison RHS elements are
+/// shifted by the distance between the two interval starts and indexed past
+/// the LHS vector. When both inputs are the same op (e.g. the last vector of a
+/// level was duplicated), only the LHS contributes.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   // Level 1, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+///   =>
+///
+///   // Level 2, vector length = 9
+///   PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+static SmallVector<int64_t> computePropagationShuffleMask(
+    ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
+    const Interval &rhsInterval, unsigned outputVectorSize) {
+  ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask();
+  ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask();
+  unsigned inputVectorSize = lhsShuffleMask.size();
+  assert(inputVectorSize == rhsShuffleMask.size() &&
+         "Expected both shuffle masks to have the same size");
+  assert(lhsInterval.first <= rhsInterval.first &&
+         "expected the LHS interval to start no later than the RHS interval");
+
+  bool hasSameInput = lhsShuffleOp == rhsShuffleOp;
+  unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+
+  // Propagate every non-poison LHS element to the same position.
+  for (unsigned i = 0; i < inputVectorSize; ++i) {
+    if (lhsShuffleMask[i] == ShuffleOp::kPoisonIndex)
+      continue;
+    assert(i < outputVectorSize && "LHS index out of bounds");
+    mask[i] = i;
+  }
+
+  // Propagate every non-poison RHS element, offset by the distance between the
+  // intervals. All LHS positions are already written at this point, so the
+  // "mask already set" check sees every LHS position.
+  if (!hasSameInput) {
+    for (unsigned i = 0; i < inputVectorSize; ++i) {
+      if (rhsShuffleMask[i] == ShuffleOp::kPoisonIndex)
+        continue;
+      unsigned rhsIdx = i + lhsRhsOffset;
+      assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
+      assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set");
+      mask[rhsIdx] = i + inputVectorSize;
+    }
+  }
+
+  LLVM_DEBUG({
+    unsigned indLv = 1;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Propagation shuffle mask computation:\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* LHS shuffle op: " << lhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* RHS shuffle op: " << rhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+  });
+
+  return mask;
+}
+
+/// Materialize the pre-computed shuffle tree configuration in the IR by
+/// generating the corresponding `vector.shuffle` ops.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// with the pre-computed shuffle tree configuration:
+///
+/// * Vector sizes per level: 0, 8, 9
+/// * Input intervals per level:
+/// * Level 0: [0,6], [1,7], [2,8], [2,8]
+/// * Level 1: [0,7], [2,8]
+/// * Level 2: [0,8]
+///
+/// =>
+///
+/// %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// The code generation consists of combining pairs of vectors at each level of
+/// the tree, using the pre-computed tree intervals and vector sizes. The
+/// algorithm generates two kinds of shuffle masks: permutation masks and
+/// propagation masks:
----------------
banach-space wrote:
```suggestion
/// algorithm generates two kinds of shuffle masks: permutation masks and
/// propagation masks:
```
https://github.com/llvm/llvm-project/pull/145740
More information about the Mlir-commits
mailing list