[Mlir-commits] [mlir] [mlir][vector] Migrate drop-lead-unit-dim to shape_cast (PR #196206)

Diego Caballero llvmlistbot at llvm.org
Fri May 8 10:06:44 PDT 2026


================
@@ -45,14 +48,89 @@ static VectorType trimLeadingOneDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-/// Return a smallVector of size `rank` containing all zeros.
-static SmallVector<int64_t> splatZero(int64_t rank) {
-  return SmallVector<int64_t>(rank, 0);
+/// Returns `value` unchanged if its type is already `newType`; otherwise
+/// creates a vector.shape_cast of `value` to `newType` and returns its result.
+static Value shapeCastVector(OpBuilder &b, Location loc, Value value,
+                             VectorType newType) {
+  // Avoid emitting a no-op shape_cast when the types already match.
+  if (value.getType() == newType)
+    return value;
+  return vector::ShapeCastOp::create(b, loc, newType, value);
 }
+
+/// Returns true when each of the first `dropCount` dimensions of `type` has
+/// size 1 and is not scalable (trivially true when `dropCount` is 0).
+static bool hasNonScalableUnitLeadingDims(VectorType type, int64_t dropCount) {
+  assert(dropCount >= 0 && dropCount <= type.getRank() &&
+         "expected a valid leading dimension count");
+  ArrayRef<bool> prefixScalable = type.getScalableDims().take_front(dropCount);
+  if (llvm::any_of(prefixScalable, [](bool isScalable) { return isScalable; }))
+    return false;
+  ArrayRef<int64_t> prefixSizes = type.getShape().take_front(dropCount);
+  return llvm::all_of(prefixSizes, [](int64_t size) { return size == 1; });
+}
+
+/// Returns true if dimension `dim` of `type` has size 1 and is not scalable.
+static bool isNonScalableUnitDim(VectorType type, int64_t dim) {
+  if (type.getScalableDims()[dim])
+    return false;
+  return type.getShape()[dim] == 1;
+}
+
+/// Returns a vector type with the element type of `type`, whose sizes and
+/// scalability flags are both reordered via `applyPermutation(…, permutation)`.
+static VectorType transposeVectorType(VectorType type,
+                                      ArrayRef<int64_t> permutation) {
+  auto permutedShape = applyPermutation(type.getShape(), permutation);
+  auto permutedScalable =
+      applyPermutation(type.getScalableDims(), permutation);
+  return VectorType::get(permutedShape, type.getElementType(),
+                         permutedScalable);
+}
+
+/// Shape-casts `operand` to the vector type obtained by removing its first
+/// `dropCount` dimensions (both sizes and scalability flags). When
+/// `dropCount` is 0 the type is unchanged and `operand` is returned as is.
+/// Callers must ensure at least one vector dimension remains after the drop.
+static Value shapeCastDroppingLeadingDims(OpBuilder &b, Location loc,
+                                          Value operand, int64_t dropCount) {
+  auto srcType = cast<VectorType>(operand.getType());
+  assert(dropCount < srcType.getRank() &&
+         "shape_cast cannot drop all vector dimensions");
+  ArrayRef<int64_t> keptShape = srcType.getShape().drop_front(dropCount);
+  ArrayRef<bool> keptScalable =
+      srcType.getScalableDims().drop_front(dropCount);
+  VectorType dstType =
+      VectorType::get(keptShape, srcType.getElementType(), keptScalable);
+  return shapeCastVector(b, loc, operand, dstType);
+}
+
+/// Shape-casts `operand` to the vector type obtained by removing dimension
+/// `dim`, which must be a non-scalable unit dimension. All other dimensions
+/// keep their relative order, sizes and scalability flags.
+static Value shapeCastDroppingDim(OpBuilder &b, Location loc, Value operand,
+                                  int64_t dim) {
+  auto srcType = cast<VectorType>(operand.getType());
+  assert(isNonScalableUnitDim(srcType, dim) &&
+         "expected a non-scalable unit dim to drop");
+
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<bool> srcScalable = srcType.getScalableDims();
+  SmallVector<int64_t> keptShape;
+  SmallVector<bool> keptScalable;
+  int64_t rank = srcType.getRank();
+  for (int64_t idx = 0; idx < rank; ++idx) {
+    if (idx != dim) {
+      keptShape.push_back(srcShape[idx]);
+      keptScalable.push_back(srcScalable[idx]);
+    }
+  }
+
+  VectorType dstType =
+      VectorType::get(keptShape, srcType.getElementType(), keptScalable);
+  return shapeCastVector(b, loc, operand, dstType);
+}
+
+static Value dropLeadingDimsForContraction(OpBuilder &b, Location loc,
+                                           Value operand, int64_t dropCount) {
+  auto oldType = cast<VectorType>(operand.getType());
+  assert(hasNonScalableUnitLeadingDims(oldType, dropCount) &&
+         "expected non-scalable leading unit dims to drop");
+
+  // vector.contract rejects 0-D vector accumulators/results. When every vector
+  // dimension is dropped, use the scalar path that vector.contract accepts.
+  if (dropCount == oldType.getRank()) {
+    llvm::Repeated<int64_t> zeros(static_cast<size_t>(dropCount), 0);
----------------
dcaballe wrote:

Couldn't we just use a SmallVector initialized to zeros?

https://github.com/llvm/llvm-project/pull/196206


More information about the Mlir-commits mailing list