[Mlir-commits] [mlir] [mlir][vector] Migrate drop-lead-unit-dim to shape_cast (PR #196206)

Andrzej Warzyński llvmlistbot at llvm.org
Fri May 8 04:06:54 PDT 2026


================
@@ -45,14 +48,89 @@ static VectorType trimLeadingOneDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-/// Return a smallVector of size `rank` containing all zeros.
-static SmallVector<int64_t> splatZero(int64_t rank) {
-  return SmallVector<int64_t>(rank, 0);
+/// Returns `value` if it already has `newType`, otherwise inserts a
+/// vector.shape_cast to `newType`.
+static Value shapeCastVector(OpBuilder &b, Location loc, Value value,
+                             VectorType newType) {
+  if (value.getType() == newType)
+    return value;
+  return vector::ShapeCastOp::create(b, loc, newType, value);
 }
+
+static bool hasNonScalableUnitLeadingDims(VectorType type, int64_t dropCount) {
+  assert(dropCount >= 0 && dropCount <= type.getRank() &&
+         "expected a valid leading dimension count");
+  ArrayRef<int64_t> leadingShape = type.getShape().take_front(dropCount);
+  ArrayRef<bool> leadingScalable = type.getScalableDims().take_front(dropCount);
+  return llvm::all_of(leadingShape, [](int64_t dim) { return dim == 1; }) &&
+         llvm::none_of(leadingScalable, [](bool scalable) { return scalable; });
+}
----------------
banach-space wrote:

[nit] IMHO, `dropCount` is too generic and doesn't compose well with `hasNonScalableUnitLeadingDims` (it's hard to map `dropCount` to what `hasNonScalableUnitLeadingDims` implies it is doing).
```suggestion
static bool hasNonScalableUnitLeadingDims(VectorType type,
                                          int64_t leadingDimsToCheckCount) {
  assert(leadingDimsToCheckCount >= 0 &&
         leadingDimsToCheckCount <= type.getRank() &&
         "expected a valid leading dimension count");
  ArrayRef<int64_t> leadingShape =
      type.getShape().take_front(leadingDimsToCheckCount);
  ArrayRef<bool> leadingScalable =
      type.getScalableDims().take_front(leadingDimsToCheckCount);
  return llvm::all_of(leadingShape, [](int64_t dim) { return dim == 1; }) &&
         llvm::none_of(leadingScalable, [](bool scalable) { return scalable; });
}
```
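To make the naming point concrete, here is a minimal standalone sketch of what the predicate checks, using the suggested parameter name. `MockVectorType` is a hypothetical stand-in for `mlir::VectorType` (only the shape and the per-dimension scalable flags are modelled) so the snippet compiles without MLIR; it illustrates the semantics under that assumption and is not the code in the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: only the shape and the
// per-dimension scalability flags are modelled.
struct MockVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  int64_t getRank() const { return static_cast<int64_t>(shape.size()); }
};

// Returns true iff the first `leadingDimsToCheckCount` dims are all 1 and
// none of them is scalable (e.g. the `[1]` in vector<[1]x4xf32> is rejected).
static bool hasNonScalableUnitLeadingDims(const MockVectorType &type,
                                          int64_t leadingDimsToCheckCount) {
  assert(leadingDimsToCheckCount >= 0 &&
         leadingDimsToCheckCount <= type.getRank() &&
         "expected a valid leading dimension count");
  auto shapeEnd = type.shape.begin() + leadingDimsToCheckCount;
  auto scalableEnd = type.scalableDims.begin() + leadingDimsToCheckCount;
  // Every checked dim must be a unit dim...
  bool allUnit = std::all_of(type.shape.begin(), shapeEnd,
                             [](int64_t dim) { return dim == 1; });
  // ...and none of the checked dims may be scalable.
  bool noneScalable = std::none_of(type.scalableDims.begin(), scalableEnd,
                                   [](bool scalable) { return scalable; });
  return allUnit && noneScalable;
}
```

For example, with `vector<1x1x4xf32>` modelled as `{{1, 1, 4}, {false, false, false}}`, checking 2 leading dims returns `true` while checking 3 returns `false`; for `vector<[1]x4xf32>` modelled as `{{1, 4}, {true, false}}`, checking 1 leading dim returns `false` because that unit dim is scalable. Checking 0 dims is trivially `true`.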

https://github.com/llvm/llvm-project/pull/196206

