[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)

Diego Caballero llvmlistbot at llvm.org
Wed Jun 4 12:34:42 PDT 2025


================
@@ -21,177 +21,259 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include <numeric>
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
-using namespace mlir::vector;
-
-/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
-                   int step = 1) {
-  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
-    assert(indices[dim] < vecType.getDimSize(dim) &&
-           "Indices are out of bound");
-    indices[dim] += step;
-    if (indices[dim] < vecType.getDimSize(dim))
-      break;
 
-    indices[dim] = 0;
-    step = 1;
+/// Perform the in-place update
+///    rhs <- lhs + rhs
+///
+/// where `rhs` is a number expressed in mixed base `base` with the most
+/// significant dimensions on the left. For example, if `rhs` is {a,b,c} and
+/// `base` is {5,3,2}, then `rhs` has value a*3*2 + b*2 + c.
+///
+/// Some examples where `base` is {5,3,2}:
+/// rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
+/// rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
+/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
+///
+/// Invalid:
+/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
+///
+/// Overflow is not detected; the carry out of the most significant dimension
+/// is dropped: rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,1} (29 + 2 = 31 wraps to 1)
+static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
+                       MutableArrayRef<int64_t> rhs) {
+
+  // For dimensions in [rhs.size() - 1, ..., 2, 1, 0] (least significant first):
+  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
+    int64_t dimBase = base[dim];
+    assert(rhs[dim] < dimBase && "rhs not in base");
+
+    int64_t incremented = rhs[dim] + lhs;
+
+    // If the incremented value reaches or exceeds the dimension base, we must
+    // carry into the next more significant dimension and repeat (the carry may
+    // propagate through several more significant dimensions).
+    lhs = incremented / dimBase;
+    rhs[dim] = incremented % dimBase;
+    if (lhs == 0)
+      break;
   }
 }
 
 namespace {
-/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
-/// vectors progressively. This iterates over the n-1 major dimensions of the
-/// n-D vector and performs rewrites into:
-///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOpNDDownCastRewritePattern
-    : public OpRewritePattern<vector::ShapeCastOp> {
+
+/// shape_cast is converted to a sequence of extract, extract_strided_slice,
+/// insert_strided_slice, and insert operations. The running example will be:
+///
+/// %0 = vector.shape_cast %arg0 :
+///         vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
+///
+/// In this example the source and result shapes share a common suffix of 7x11.
+/// This means we can always decompose the shape_cast into extract, insert, and
+/// their strided equivalents, on vectors with shape suffix 7x11.
+///
+/// The greatest common divisor (gcd) of the source and result dimensions
+/// immediately preceding the common suffix is gcd(4,6) = 2. The algorithm
+/// operates on vectors whose shapes are `multiples` of (what we define as) the
+/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
+///
+///         vector<2x2x3x4x7x11xi8> to
+///             vector<8x6x7x11xi8>
+///                      | ||||
+///                      | ++++------------> common suffix of 7x11
+///                      +----------------->    gcd(4,6) is 2 | |
+///                                                         | | |
+///                                                         v v v
+///                                 atomic shape   <-----   2x7x11
----------------
dcaballe wrote:

Nice! This makes sense to me...

https://github.com/llvm/llvm-project/pull/140800


More information about the Mlir-commits mailing list