[Mlir-commits] [mlir] 97043e5 - [mlir][Vector][GPU] Distribute expanding `shape_cast` ops (#183830)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 3 05:10:45 PST 2026


Author: Artem Gindinson
Date: 2026-03-03T14:10:40+01:00
New Revision: 97043e50ad41c9da94a5cb48417f5139d6a84c8d

URL: https://github.com/llvm/llvm-project/commit/97043e50ad41c9da94a5cb48417f5139d6a84c8d
DIFF: https://github.com/llvm/llvm-project/commit/97043e50ad41c9da94a5cb48417f5139d6a84c8d.diff

LOG: [mlir][Vector][GPU] Distribute expanding `shape_cast` ops (#183830)

The initial implementation of `shape_cast` distribution only focused on
scenarios with collapsing shape casts. In downstream pipelines such as
IREE, commit 962a9a3 exposes an issue with this implementation: the
rank-expanding cast ops (stemming from the new `vector.broadcast`
canonicalization) silently fall through to the "collapsing-or-no-op"
logic. This leads to rank-mismatch bugs and failing validation
assertions when distributing fairly common reshaping sequences
encountered after CSE/canonicalization, such as the following:
```
  // Example 1: gather op
  %weight = arith.constant dense_resource<__elided__> : tensor<256xi8>
  %c0 = arith.constant 0 : index
  ...
  %expand = vector.shape_cast <...> : vector<1xindex> to vector<1x1xindex>
  %gather = vector.gather %weight[%c0] [%expand], <...>, <...> : memref<256xi8>, vector<1x1xindex>, vector<1x1xi1>, vector<1x1xi8> into vector<1x1xi8>
  %collapse_back = vector.shape_cast %gather : vector<1x1xi8> to vector<1xi8>
  // Example 2: multi-reduction
  %expand = vector.shape_cast <...>: vector<1x96xi32> to vector<1x2x48xi32>
  %reduce = vector.multi_reduction <add>, %expand, <...> [1, 2]: vector<1x2x48xi32> to vector<1x1xi32>
  %collapse = vector.shape_cast %reduce : vector<1x1xi32> to vector<1xi32>
```

This commit adds initial handling of expanding `shape_cast`s, covering
three scenarios:
- if all "excess" leading dimensions of the per-lane (distributed)
destination shape are unit, the work is clearly not distributed across
them, so we strip the same number of leading unit dimensions from the
per-lane yielded type;
- if the source type within the warp code is of rank 1, we can still
determine the corresponding per-lane source type cleanly by multiplying
the dimensions of the per-lane yielded type (i.e. taking its element
count);
- if neither of the above holds, the pattern explicitly fails for such
expanding `shape_cast`s, since the dimension-specific distribution
parameters are ambiguous, at least from within this pattern's context.

---------

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b496711df8bb8..b4d500212c770 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1077,16 +1077,11 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
     VectorType castOriginalType = oldCastOp.getSourceVectorType();
     VectorType castResultType = castDistributedType;
 
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
-    }
+    FailureOr<VectorType> maybeSrcType =
+        inferDistributedSrcType(castDistributedType, castOriginalType);
+    if (failed(maybeSrcType))
+      return failure();
+    castDistributedType = *maybeSrcType;
 
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1099,6 +1094,46 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
     return success();
   }
+
+private:
+  static FailureOr<VectorType>
+  inferDistributedSrcType(VectorType distributedType, VectorType srcType) {
+    unsigned distributedRank = distributedType.getRank();
+    unsigned srcRank = srcType.getRank();
+    if (distributedRank == srcRank)
+      // Nothing to do.
+      return distributedType;
+    if (distributedRank < srcRank) {
+      // If the distributed type has a smaller rank than the original type,
+      // prepend with unit dimensions to make the types the same length.
+      SmallVector<int64_t> shape(srcRank - distributedRank, 1);
+      llvm::append_range(shape, distributedType.getShape());
+      return VectorType::get(shape, distributedType.getElementType());
+    }
+    // Handle the expanding shape_cast's.
+    //
+    // If the casted-from type has one rank, we can assert that the element
+    // count in that rank will match the full thread-level element count of
+    // the yielded type.
+    // Note that getNumElements() will correctly "flatten" the shape of the
+    // specific shape_cast's distributed type (its distribution may be
+    // different from the overall warp size, e.g. if the cast is applied to
+    // a result of a gather).
+    if (srcRank == 1)
+      return VectorType::get(distributedType.getNumElements(),
+                             srcType.getElementType());
+    // Try to strip leading unit dimensions to match the ranks. We bail out
+    // for more complex tile sizes, because those would require us to
+    // determine the specific distribution parameters to threads, which is
+    // unfeasible within this pattern.
+    unsigned excessDims = distributedRank - srcRank;
+    ArrayRef<int64_t> shape = distributedType.getShape();
+    if (!llvm::all_of(shape.take_front(excessDims),
+                      [](int64_t d) { return d == 1; }))
+      return failure();
+    return VectorType::get(shape.drop_front(excessDims),
+                           distributedType.getElementType());
+  }
 };
 
 /// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 63c9d9b7a9bf8..278f02ed033ab 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1604,6 +1604,74 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 
 // -----
 
+func.func @warp_propagate_shape_cast_rank_extending(
+    %laneid: index, %src: memref<4096xf32>) -> vector<1x1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[1024] -> (vector<1x1x4xf32>) {
+    %1 = vector.transfer_read %src[%c0], %cst : memref<4096xf32>, vector<4096xf32>
+    %2 = vector.shape_cast %1 : vector<4096xf32> to vector<32x4x32xf32>
+    gpu.yield %2 : vector<32x4x32xf32>
+  }
+  return %r : vector<1x1x4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_rank_extending
+//  CHECK-PROP-NOT:    gpu.warp_execute_on_lane_0
+//      CHECK-PROP:    %[[READ:.+]] = vector.transfer_read {{.+}} : memref<4096xf32>, vector<4xf32>
+//      CHECK-PROP:    %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf32> to vector<1x1x4xf32>
+//      CHECK-PROP:    return %[[CAST]] : vector<1x1x4xf32>
+
+// -----
+
+// Shape cast from a single-rank source.
+// The per-lane source type is expected to be obtained by flattening the distributed result dims
+// (here, 1x2x4 = 8).
+func.func @warp_propagate_shape_cast_rank_extending_flat(
+    %laneid: index, %src: memref<128xf32>) -> vector<1x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x2x4xf32>) {
+    %1 = vector.transfer_read %src[%c0], %cst : memref<128xf32>, vector<128xf32>
+    %2 = vector.shape_cast %1 : vector<128xf32> to vector<4x4x8xf32>
+    gpu.yield %2 : vector<4x4x8xf32>
+  }
+  return %r : vector<1x2x4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_rank_extending_flat
+//  CHECK-PROP-NOT:    gpu.warp_execute_on_lane_0
+//      CHECK-PROP:    %[[READ:.+]] = vector.transfer_read {{.+}} : memref<128xf32>, vector<8xf32>
+//      CHECK-PROP:    %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<8xf32> to vector<1x2x4xf32>
+//      CHECK-PROP:    return %[[CAST]] : vector<1x2x4xf32>
+
+// -----
+
+// Negative test: rank-2 source with a non-unit dimension.
+// The per-lane distribution across source dimensions is ambiguous, so the pattern under
+// test isn't expected to handle it.
+func.func @warp_do_not_propagate_shape_cast_rank_extending_ambiguous(
+    %laneid: index, %src: memref<32x4xf32>) -> vector<2x4x8xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[2] -> (vector<2x4x8xf32>) {
+    %1 = vector.transfer_read %src[%c0, %c0], %cst : memref<32x4xf32>, vector<32x4xf32>
+    %2 = vector.shape_cast %1 : vector<32x4xf32> to vector<4x4x8xf32>
+    gpu.yield %2 : vector<4x4x8xf32>
+  }
+  return %r : vector<2x4x8xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_do_not_propagate_shape_cast_rank_extending_ambiguous
+//      CHECK-PROP:    gpu.warp_execute_on_lane_0
+//      CHECK-PROP:      vector.transfer_read {{.+}} : memref<32x4xf32>, vector<32x4xf32>
+//      CHECK-PROP:      vector.shape_cast {{.+}} : vector<32x4xf32> to vector<4x4x8xf32>
+//      CHECK-PROP:      gpu.yield {{.+}} : vector<4x4x8xf32>
+//      CHECK-PROP:    }
+//  CHECK-PROP-NOT:    vector.shape_cast
+
+// -----
+
 func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {
   %f0 = arith.constant 0.000000e+00 : f32
   %r = gpu.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {


        


More information about the Mlir-commits mailing list