[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` tree transformation (PR #145740)

Diego Caballero llvmlistbot at llvm.org
Wed Jul 2 11:38:52 PDT 2025


================
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+///   1. Vectors are shuffled in pairs by order of appearance in
+///      the `vector.from_elements` operand list.
+///   2. Each input vector to each level is used only once.
+///   3. The number of levels in the tree is:
+///        ceil(log2(# `vector.to_elements` ops)).
+///   4. Vectors at each level of the tree have the same vector length.
+///   5. Vector positions that do not need to be shuffled are represented with
+///      poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+///   %0:4 = vector.to_elements %a : vector<4xf32>
+///   %1:4 = vector.to_elements %b : vector<4xf32>
+///   %2:4 = vector.to_elements %c : vector<4xf32>
+///   %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+///                             %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+///                               : vector<12xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+///     : vector<4xf32>, vector<4xf32>
+///   %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+///     : vector<4xf32>, vector<4xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+///                                                  6, 7, 8, 9, 10, 11]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * The shuffle tree has two levels:
+///         - Level 1 = (%shuffle0, %shuffle1)
+///         - Level 2 = (%result)
+///     * `%a` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%0#0` and `%1#0`).
+///     * `%c` is shuffled with itself because the number of
+///       `vector.to_elements` ops is odd.
+///     * The vector lengths for the first and second levels are 8 and 12,
+///       respectively.
+///     * `%shuffle1` uses poison values to match the vector length of its
+///       tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+///     : vector<5xf32>, vector<5xf32>
+///   %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+///     : vector<5xf32>, vector<5xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * `%c` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%2#2` and `%1#1`).
+///     * `%a` is shuffled with itself because the number of
+///       `vector.to_elements` ops is odd.
+///     * The vector lengths for the first and second levels are 8 and 9,
+///       respectively.
+///     * `%shuffle0` uses poison values to mark unused vector positions and
+///       match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+  VectorShuffleTreeBuilder() = delete;
+  VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+                           ArrayRef<ToElementsOp> toElemDefs);
+
+  /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+  /// and compute the shuffle tree configuration. This method does not generate
+  /// any IR.
+  LogicalResult computeShuffleTree();
+
+  /// Materialize the shuffle tree configuration computed by
+  /// `computeShuffleTree` in the IR.
+  Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+  // IR input information.
+  FromElementsOp fromElementsOp;
+  SmallVector<ToElementsOp> toElementsDefs;
+
+  // Shuffle tree configuration.
+  unsigned numLevels;
+  SmallVector<unsigned> vectorSizePerLevel;
+  /// Holds the range of positions in the final output that each vector input
+  /// in the tree is contributing to.
+  SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+  // Utility methods to compute the shuffle tree configuration.
+  void computeInputVectorIntervals();
+  void computeOutputVectorSizePerLevel();
+
+  /// Dump the shuffle tree configuration.
+  void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+    FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+    : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+  assert(fromElementsOp && "from_elements op is required");
+  assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+  // Duplicate the last vector if the number of `vector.to_elements` is odd to
+  // simplify the shuffle tree algorithm.
+  if (toElementsDefs.size() % 2 != 0) {
+    toElementsDefs.push_back(toElementsDefs.back());
+  }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0; the last one is duplicated to make the
+/// number of inputs even), so we compute the interval for each input vector:
+///
+///    * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+///    * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+///    * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+///    * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
----------------
dcaballe wrote:

The intervals model what an input vector contributes to the final output, but they belong to the input vectors of a specific level. Each pair of input intervals is used to produce an output interval, which becomes an input interval for the next level. That is:

```
            input vector 0             input vector 1        input vector 2              input vector 3
                  |                          |                     |                           |

Iteration 0: 
          input interval 0           input interval 1       input interval 2           input interval 3 
                          \         /                                       \         /
------------------------------------------- tree level 0 --------------------------------------- 
                               |                                                 | 
                        output interval 0-1                          output interval 2-3 
                               |                                             |
Iteration 1:      
                               |                                             |
                         input interval 0-1                          input interval 2-3 
                                           \                        /
------------------------------------------- tree level 1 ---------------------------------------
                                                     | 
                                           output interval 0-1-2-3

```

I think using `output` in the name of what is currently called "input interval" would be confusing.
Actually, I can add this amazing ASCII diagram to the doc and elaborate a bit more if that helps! :)
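
To make the input/output interval relationship concrete, here is a minimal standalone sketch in plain C++. It is not the code from the patch: the helper names `computeLevel0Intervals` and `mergePairs` are hypothetical, and it assumes each output position is described only by the index of the input vector that supplies it, with inputs numbered by order of appearance. `Interval` and `kMaxUnsigned` mirror the definitions in the diff.

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Closed interval of positions in the final output, as in the patch.
using Interval = std::pair<unsigned, unsigned>;
// Sentinel for "no position seen yet".
constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();

// Hypothetical helper: `sourceOfOutputPos[i]` is the index of the input vector
// that supplies output position `i`. Returns, for each input vector, the
// [first, last] output positions it contributes to.
std::vector<Interval>
computeLevel0Intervals(const std::vector<unsigned> &sourceOfOutputPos,
                       unsigned numInputs) {
  std::vector<Interval> intervals(numInputs, Interval(kMaxUnsigned, 0u));
  for (unsigned pos = 0; pos < sourceOfOutputPos.size(); ++pos) {
    Interval &interval = intervals[sourceOfOutputPos[pos]];
    interval.first = std::min(interval.first, pos);
    interval.second = std::max(interval.second, pos);
  }
  return intervals;
}

// Hypothetical helper: one tree level. Each pair of input intervals yields one
// output interval (their hull), which is an input interval of the next level.
// The last input is duplicated if the count is odd.
std::vector<Interval> mergePairs(std::vector<Interval> inputs) {
  if (inputs.size() % 2 != 0)
    inputs.push_back(inputs.back());
  std::vector<Interval> outputs;
  for (std::size_t i = 0; i + 1 < inputs.size(); i += 2)
    outputs.emplace_back(std::min(inputs[i].first, inputs[i + 1].first),
                         std::max(inputs[i].second, inputs[i + 1].second));
  return outputs;
}
```

For the 3x `vector<5xf32>` example in the diff, numbering the inputs `%2`, `%1`, `%0` as 0, 1, 2 gives the source list `{0, 1, 2, 2, 1, 0, 0, 1, 2}`. `computeLevel0Intervals` then yields `[0,6]`, `[1,7]`, `[2,8]`, which matches the doc comment. One call to `mergePairs` gives `[0,7]` and `[2,8]`, and a second call gives `[0,8]`.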

https://github.com/llvm/llvm-project/pull/145740

