[Mlir-commits] [mlir] [mlir][Vector][GPU] Distribute expanding `shape_cast` ops (PR #183830)

Artem Gindinson llvmlistbot at llvm.org
Sun Mar 1 23:29:42 PST 2026


https://github.com/AGindinson updated https://github.com/llvm/llvm-project/pull/183830

>From 92d1fc2805009c84357c88fafb29bd8a96ee9c7c Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 27 Feb 2026 11:10:45 +0000
Subject: [PATCH 1/2] [mlir][Vector][GPU] Distribute expanding `shape_cast` ops

The initial implementation of `shape_cast` distribution only focused on
scenarios with collapsing shape casts. Within downstream pipelines such as
IREE, commit 962a9a3 exposes an issue with this implementation, where the
rank-expanding cast ops (stemming from the new `vector.broadcast`
canonicalization) silently fall through to the "collapsing-or-no-op" logic.
This causes rank-mismatch bugs and trips validation assertions when
distributing rather common reshaping sequences encountered after CSE/
canonicalization, such as the ones below:
```
  // Example 1: gather op
  %weight = arith.constant dense_resource<__elided__> : tensor<256xi8>
  %c0 = arith.constant 0 : index
  ...
  %expand = vector.shape_cast <...> : vector<1xindex> to vector<1x1xindex>
  %gather = vector.gather %weight[%c0] [%expand], <...>, <...> : tensor<256xi8>, vector<1x1xindex>, vector<1x1xi1>, vector<1x1xi8> into vector<1x1xi8>
  %collapse_back = vector.shape_cast %gather : vector<1x1xi8> to vector<1xi8>
  // Example 2: multi-reduction
  %expand = vector.shape_cast <...>: vector<1x96xi32> to vector<1x2x48xi32>
  %reduce = vector.multi_reduction <add>, %expand, <...> [1, 2]: vector<1x2x48xi32> to vector<1x1xi32>
  %collapse = vector.shape_cast %reduce : vector<1x1xi32> to vector<1xi32>
```

This commit adds initial handling of expanding `shape_cast`s, going through
the three scenarios:
- if all "excess" dimensions at the front of the destination shape are unit,
  it's clear that the work is not distributed across those, so we strip the
  same number of unit dimensions from the per-lane yielded type;
- if the source type within the warp code is of rank 1, we still determine
  the corresponding output type cleanly by multiplying the dimensions of
  the per-lane yield type;
- if neither of the above is true, explicitly fail the pattern for such
  expanding `shape_cast`s. Dimension-specific distribution parameters
  are deemed ambiguous, at least from within this pattern's context.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 .../Vector/Transforms/VectorDistribute.cpp    | 55 ++++++++++++---
 .../Vector/vector-warp-distribute.mlir        | 68 +++++++++++++++++++
 2 files changed, 113 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 743fb51bab1ab..eb5bd76338e65 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1077,16 +1077,11 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
     VectorType castOriginalType = oldCastOp.getSourceVectorType();
     VectorType castResultType = castDistributedType;
 
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
-    }
+    FailureOr<VectorType> maybeSrcType =
+        inferDistributedSrcType(castDistributedType, castOriginalType);
+    if (failed(maybeSrcType))
+      return failure();
+    castDistributedType = *maybeSrcType;
 
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1099,6 +1094,46 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
     return success();
   }
+
+private:
+  static FailureOr<VectorType>
+  inferDistributedSrcType(VectorType distributedType, VectorType srcType) {
+    unsigned distributedRank = distributedType.getRank();
+    unsigned srcRank = srcType.getRank();
+    if (distributedRank == srcRank)
+      // Nothing to do.
+      return distributedType;
+    if (distributedRank < srcRank) {
+      // If the distributed type has a smaller rank than the original type,
+      // prepend with unit dimensions to make the types the same length.
+      SmallVector<int64_t> shape(srcRank - distributedRank, 1);
+      llvm::append_range(shape, distributedType.getShape());
+      return VectorType::get(shape, distributedType.getElementType());
+    }
+    // Handle the expanding shape_cast's.
+    //
+    // If the casted-from type has rank 1, we can assert that the element
+    // count in that dimension will match the full thread-level element count
+    // of the yielded type.
+    // Note that getNumElements() will correctly "flatten" the shape of the
+    // specific shape_cast's distributed type (its distribution may be
+    // different from the overall warp size, e.g. if the cast is applied to
+    // a result of a gather).
+    if (srcRank == 1)
+      return VectorType::get({distributedType.getNumElements()},
+                             srcType.getElementType());
+    // Try to strip leading unit dimensions to match the ranks. We bail out
+    // for more complex tile sizes, because those would require us to
+    // determine the specific distribution parameters to threads, which is
+    // unfeasible within this pattern.
+    unsigned excessDims = distributedRank - srcRank;
+    ArrayRef<int64_t> shape = distributedType.getShape();
+    if (!llvm::all_of(shape.take_front(excessDims),
+                      [](int64_t d) { return d == 1; }))
+      return failure();
+    return VectorType::get(shape.drop_front(excessDims),
+                           distributedType.getElementType());
+  }
 };
 
 /// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 63c9d9b7a9bf8..278f02ed033ab 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1604,6 +1604,74 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 
 // -----
 
+func.func @warp_propagate_shape_cast_rank_extending(
+    %laneid: index, %src: memref<4096xf32>) -> vector<1x1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[1024] -> (vector<1x1x4xf32>) {
+    %1 = vector.transfer_read %src[%c0], %cst : memref<4096xf32>, vector<4096xf32>
+    %2 = vector.shape_cast %1 : vector<4096xf32> to vector<32x4x32xf32>
+    gpu.yield %2 : vector<32x4x32xf32>
+  }
+  return %r : vector<1x1x4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_rank_extending
+//  CHECK-PROP-NOT:    gpu.warp_execute_on_lane_0
+//      CHECK-PROP:    %[[READ:.+]] = vector.transfer_read {{.+}} : memref<4096xf32>, vector<4xf32>
+//      CHECK-PROP:    %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf32> to vector<1x1x4xf32>
+//      CHECK-PROP:    return %[[CAST]] : vector<1x1x4xf32>
+
+// -----
+
+// Shape cast from a single-rank source.
+// The per-lane source type is expected to be obtained by flattening the distributed result dims
+// (here, 1x2x4 = 8).
+func.func @warp_propagate_shape_cast_rank_extending_flat(
+    %laneid: index, %src: memref<128xf32>) -> vector<1x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x2x4xf32>) {
+    %1 = vector.transfer_read %src[%c0], %cst : memref<128xf32>, vector<128xf32>
+    %2 = vector.shape_cast %1 : vector<128xf32> to vector<4x4x8xf32>
+    gpu.yield %2 : vector<4x4x8xf32>
+  }
+  return %r : vector<1x2x4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_rank_extending_flat
+//  CHECK-PROP-NOT:    gpu.warp_execute_on_lane_0
+//      CHECK-PROP:    %[[READ:.+]] = vector.transfer_read {{.+}} : memref<128xf32>, vector<8xf32>
+//      CHECK-PROP:    %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<8xf32> to vector<1x2x4xf32>
+//      CHECK-PROP:    return %[[CAST]] : vector<1x2x4xf32>
+
+// -----
+
+// Negative test: rank-2 source with a non-unit dimension.
+// The per-lane distribution across source dimensions is ambiguous, so the pattern under
+// test isn't expected to handle it.
+func.func @warp_do_not_propagate_shape_cast_rank_extending_ambiguous(
+    %laneid: index, %src: memref<32x4xf32>) -> vector<2x4x8xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[2] -> (vector<2x4x8xf32>) {
+    %1 = vector.transfer_read %src[%c0, %c0], %cst : memref<32x4xf32>, vector<32x4xf32>
+    %2 = vector.shape_cast %1 : vector<32x4xf32> to vector<4x4x8xf32>
+    gpu.yield %2 : vector<4x4x8xf32>
+  }
+  return %r : vector<2x4x8xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_do_not_propagate_shape_cast_rank_extending_ambiguous
+//      CHECK-PROP:    gpu.warp_execute_on_lane_0
+//      CHECK-PROP:      vector.transfer_read {{.+}} : memref<32x4xf32>, vector<32x4xf32>
+//      CHECK-PROP:      vector.shape_cast {{.+}} : vector<32x4xf32> to vector<4x4x8xf32>
+//      CHECK-PROP:      gpu.yield {{.+}} : vector<4x4x8xf32>
+//      CHECK-PROP:    }
+//  CHECK-PROP-NOT:    vector.shape_cast
+
+// -----
+
 func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {
   %f0 = arith.constant 0.000000e+00 : f32
   %r = gpu.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {

>From 252fa02896f967ebbc5113e23a60053d877d033b Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Sat, 28 Feb 2026 10:52:59 +0100
Subject: [PATCH 2/2] [fixup] drop obsolete initializer list

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index eb5bd76338e65..2a2b7714a84b8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1120,7 +1120,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
     // different from the overall warp size, e.g. if the cast is applied to
     // a result of a gather).
     if (srcRank == 1)
-      return VectorType::get({distributedType.getNumElements()},
+      return VectorType::get(distributedType.getNumElements(),
                              srcType.getElementType());
     // Try to strip leading unit dimensions to match the ranks. We bail out
     // for more complex tile sizes, because those would require us to



More information about the Mlir-commits mailing list