[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` tree transformation (PR #145740)
Diego Caballero
llvmlistbot at llvm.org
Wed Jun 25 17:47:43 PDT 2025
================
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+/// 1. Vectors are shuffled in pairs by order of appearance in
+/// the `vector.from_elements` operand list.
+/// 2. Each input vector to each level is used only once.
+/// 3. The number of levels in the tree is:
+/// ceil(log2(# `vector.to_elements` ops)).
+/// 4. Vectors at each level of the tree have the same vector length.
+/// 5. Vector positions that do not need to be shuffled are represented with
+/// poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+/// %0:4 = vector.to_elements %a : vector<4xf32>
+/// %1:4 = vector.to_elements %b : vector<4xf32>
+/// %2:4 = vector.to_elements %c : vector<4xf32>
+/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+/// : vector<12xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+/// : vector<4xf32>, vector<4xf32>
+/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+/// : vector<4xf32>, vector<4xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+/// 6, 7, 8, 9, 10, 11]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * The shuffle tree has two levels:
+/// - Level 1 = (%shuffle0, %shuffle1)
+/// - Level 2 = (%result)
+/// * `%a` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
+/// * `%c` is shuffled with itself because the number of
+/// `vector.to_elements` ops is odd.
+/// * The vector lengths for the first and second levels are 8 and 12,
+/// respectively.
+/// * `%shuffle1` uses poison values to match the vector length of its
+/// tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * `%c` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
+/// * `%a` is shuffled with itself because the number of
+/// `vector.to_elements` ops is odd.
+/// * The vector lengths for the first and second levels are 8 and 9,
+/// respectively.
+/// * `%shuffle0` uses poison values to mark unused vector positions and
+/// match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+ /// A builder always operates on a concrete `vector.from_elements` op, so
+ /// default construction is disabled.
+ VectorShuffleTreeBuilder() = delete;
+ /// Creates a builder for `fromElemOp` and the `vector.to_elements` ops in
+ /// `toElemDefs`. The constructor asserts that `fromElemOp` is non-null and
+ /// that `toElemDefs` is not empty.
+ VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+ ArrayRef<ToElementsOp> toElemDefs);
+
+ /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+ /// and compute the shuffle tree configuration. This method does not generate
+ /// any IR.
+ /// NOTE(review): the conditions under which this returns failure are not
+ /// visible from the declaration -- see the definition.
+ LogicalResult computeShuffleTree();
+
+ /// Materialize the shuffle tree configuration computed by
+ /// `computeShuffleTree` in the IR.
+ /// NOTE(review): presumably `computeShuffleTree` must have succeeded before
+ /// this is called and the returned value is the root of the tree -- confirm
+ /// against the definition.
+ Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+ // IR input information.
+ /// The `vector.from_elements` op passed to the constructor.
+ FromElementsOp fromElementsOp;
+ /// Owned copy of the `vector.to_elements` ops passed to the constructor.
+ SmallVector<ToElementsOp> toElementsDefs;
+
+ // Shuffle tree configuration.
+ /// Number of levels in the shuffle tree.
+ /// NOTE(review): this member has no in-class initializer and is not part of
+ /// the constructor's member initializer list; presumably it is assigned while
+ /// computing the tree configuration -- confirm it is never read before that.
+ unsigned numLevels;
+ /// Vector length used at each tree level (presumably indexed by level --
+ /// TODO confirm).
+ SmallVector<unsigned> vectorSizePerLevel;
+ /// Holds the range of positions in the final output that each vector input
+ /// in the tree is contributing to.
+ SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+ // Utility methods to compute the shuffle tree configuration.
+ // NOTE(review): judging by their names, these populate
+ // `inputIntervalsPerLevel` and `vectorSizePerLevel`, respectively; their
+ // definitions are not visible here -- confirm.
+ void computeInputVectorIntervals();
+ void computeOutputVectorSizePerLevel();
+
+ /// Dump the shuffle tree configuration.
+ void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+ FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+ : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+ assert(fromElementsOp && "from_elements op is required");
+ assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+ // Duplicate the last vector if the number of `vector.to_elements` is odd to
+ // simplify the shuffle tree algorithm.
+ if (toElementsDefs.size() % 2 != 0) {
----------------
dcaballe wrote:
This check refers to the number of `vector.to_elements` inputs to combine, and we want to be able to combine an arbitrary number of inputs. If that number is not even, we duplicate the last input to simplify the algorithm (the shuffle for that input would have the same input vector twice). Does that make sense?
https://github.com/llvm/llvm-project/pull/145740
More information about the Mlir-commits
mailing list