[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri May 30 09:52:37 PDT 2025
================
@@ -21,177 +21,258 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include <numeric>
#define DEBUG_TYPE "vector-shape-cast-lowering"
using namespace mlir;
-using namespace mlir::vector;
-
-/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
- int step = 1) {
- for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
- assert(indices[dim] < vecType.getDimSize(dim) &&
- "Indices are out of bound");
- indices[dim] += step;
- if (indices[dim] < vecType.getDimSize(dim))
- break;
- indices[dim] = 0;
- step = 1;
+/// Perform the inplace update
+/// rhs <- lhs + rhs
+///
+/// where `rhs` is a number expressed in mixed base `base` with most significant
+/// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is
+/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
+///
+/// Some examples where `base` is {5,3,2}:
+/// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1}
+/// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0}
+/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
+///
+/// Invalid:
+/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
+///
+/// Overflows not handled correctly:
+/// rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1})
+static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
+ MutableArrayRef<int64_t> rhs) {
+
+ // For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]:
+ for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
+ int64_t dimBase = base[dim];
+ assert(rhs[dim] < dimBase && "rhs not in base");
+
+ int64_t incremented = rhs[dim] + lhs;
+
+ // If the incremented value exceeds the dimension base, we must spill to the
+ // next most significant dimension and repeat (we might need to spill to
+ // more significant dimensions multiple times).
+ lhs = incremented / dimBase;
+ rhs[dim] = incremented % dimBase;
+ if (lhs == 0)
+ break;
}
}
namespace {
-/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
-/// vectors progressively. This iterates over the n-1 major dimensions of the
-/// n-D vector and performs rewrites into:
-/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOpNDDownCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
+
+/// shape_cast is converted to a sequence of extract, extract_strided_slice,
+/// insert_strided_slice, and insert operations. The running example will be:
+///
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
+///
+/// In this example the source and result shapes share a common suffix of 7x11.
+/// This means we can always decompose the shape_cast into extract, insert, and
+/// their strided equivalents, on vectors with shape suffix 7x11.
+///
+/// The greatest common divisor (gcd) of the first dimension preceding the
+/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
+/// on vectors with shapes that are `multiples` of (what we define as) the
+/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
+///
+/// vector<2x2x3x4x7x11xi8> to
+/// vector<8x6x7x11xi8>
+/// ^^^^ ---> common suffix of 7x11
+/// ^ ---> gcd(4,6) is 2 | |
+/// | | |
+/// v v v
+/// atomic shape <----- 2x7x11
+///
+///
+///
+/// The decomposition implemented in this pattern consists of a sequence of
+/// repeated steps:
+///
+/// (1) Extract vectors from the suffix of the source.
+/// In our example this is 2x2x3x4x7x11 -> 4x7x11.
+///
+/// (2) Do extract_strided_slice down to the atomic shape.
+/// In our example this is 4x7x11 -> 2x7x11.
+///
+/// (3) Do insert_strided_slice to the suffix of the result.
+/// In our example this is 2x7x11 -> 6x7x11.
+///
+/// (4) insert these vectors into the result vector.
+/// In our example this is 6x7x11 -> 8x6x7x11.
+///
+/// These steps occur with different periods. In this example
+/// (1) occurs 12 times,
+/// (2) and (3) occur 24 times, and
+/// (4) occurs 8 times.
+///
+/// Two special cases are handled separately:
+/// (1) A shape_cast that just does leading 1 insertion/removal
+/// (2) A shape_cast where the gcd is 1.
----------------
banach-space wrote:
Where are these cases handled?
https://github.com/llvm/llvm-project/pull/140800
More information about the Mlir-commits
mailing list