[Mlir-commits] [mlir] Revert "[mlir][vector] Migrate drop-lead-unit-dim to shape_cast (#196… (PR #199546)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 25 09:18:29 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Omair Javaid (omjavaid)

<details>
<summary>Changes</summary>

This reverts commit 24b8bb18f3417419cbd16fcd31f4e2842df952a1.

This broke AArch64 SVE Linux buildbots; however, it was not reported due to a glitch in the buildbot infrastructure. The following bots are failing:

https://lab.llvm.org/buildbot/#/builders/121
https://lab.llvm.org/buildbot/#/builders/41
https://lab.llvm.org/buildbot/#/builders/4
https://lab.llvm.org/buildbot/#/builders/199
https://lab.llvm.org/buildbot/#/builders/17
https://lab.llvm.org/buildbot/#/builders/198
https://lab.llvm.org/buildbot/#/builders/143

---

Patch is 107.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/199546.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+176-272) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+18-20) 
- (modified) mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir (+7-23) 
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+149-281) 
- (modified) mlir/test/Dialect/Vector/vector-transforms.mlir (+6-8) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index aad42039300e3..26a702ef0f512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -7,9 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <numeric>
-#include <utility>
 
-#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -17,7 +15,6 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/Repeated.h"
 #include "llvm/ADT/STLExtras.h"
 
 #define DEBUG_TYPE "vector-drop-unit-dim"
@@ -25,9 +22,9 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-// Trims leading unit dimensions from `oldType` and returns the result type.
-static VectorType trimLeadingUnitDims(VectorType oldType,
-                                      bool zeroDimsAllowed) {
+// Trims leading one dimensions from `oldType` and returns the result type.
+// Returns `vector<1xT>` if `oldType` only has one element.
+static VectorType trimLeadingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
   ArrayRef<int64_t> newShape = oldShape;
 
@@ -40,117 +37,22 @@ static VectorType trimLeadingUnitDims(VectorType oldType,
     newScalableDims = newScalableDims.drop_front(1);
   }
 
-  // Some vector ops forbid 0-D vectors.
-  if (!zeroDimsAllowed && newShape.empty()) {
+  // Make sure we have at least 1 dimension per vector type requirements.
+  if (newShape.empty()) {
     newShape = oldShape.take_back();
     newScalableDims = oldType.getScalableDims().take_back();
   }
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-static bool isNonScalableUnitDim(VectorType type, int64_t dim) {
-  assert(dim >= 0 && dim < type.getRank() &&
-         "expected a valid vector dimension");
-  return type.getShape()[dim] == 1 && !type.getScalableDims()[dim];
+/// Return a smallVector of size `rank` containing all zeros.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+  return SmallVector<int64_t>(rank, 0);
 }
-
-/// Returns true if the first `k` dimensions of `type` are non-scalable unit
-/// dimensions.
-static bool areLeadingDimsUnit(VectorType type, int64_t k) {
-  assert(k >= 0 && k <= type.getRank() &&
-         "expected a valid leading dimension count");
-  return llvm::all_of(llvm::seq<int64_t>(0, k), [&](int64_t dim) {
-    return isNonScalableUnitDim(type, dim);
-  });
-}
-
-static bool areLeadingDimsUnitAfterPermutation(VectorType type,
-                                               ArrayRef<int64_t> permutation,
-                                               int64_t k) {
-  assert(k >= 0 && k <= static_cast<int64_t>(permutation.size()) &&
-         "expected a valid leading dimension count");
-  return llvm::all_of(permutation.take_front(k), [&](int64_t dim) {
-    return isNonScalableUnitDim(type, dim);
-  });
-}
-
-/// Shape-casts `operand` to the vector type obtained by dropping dimension
-/// `dim`, which must be non-scalable and unit-sized.
-static Value dropUnitDim(OpBuilder &b, Location loc, Value operand,
-                         int64_t dimToDrop, bool zeroDimsAllowed) {
-  auto oldType = cast<VectorType>(operand.getType());
-  assert(isNonScalableUnitDim(oldType, dimToDrop) &&
-         "expected a non-scalable unit dim to drop");
-  int64_t rank = oldType.getRank();
-  assert((zeroDimsAllowed || rank > 1) &&
-         "target op does not allow 0-D vectors");
-
-  SmallVector<int64_t> newShape;
-  SmallVector<bool> newScalableDims;
-  newShape.reserve(rank - 1);
-  newScalableDims.reserve(rank - 1);
-  for (auto [i, size, scalable] :
-       llvm::enumerate(oldType.getShape(), oldType.getScalableDims())) {
-    if (static_cast<int64_t>(i) == dimToDrop)
-      continue;
-    newShape.push_back(size);
-    newScalableDims.push_back(scalable);
-  }
-
-  return b.createOrFold<vector::ShapeCastOp>(
-      loc, VectorType::get(newShape, oldType.getElementType(), newScalableDims),
-      operand);
-}
-
-/// Shape-casts `operand` to the vector type obtained by dropping the first
-/// `k` non-scalable unit dimensions.
-static Value dropLeadingUnitDims(OpBuilder &b, Location loc, Value operand,
-                                 int64_t k, bool zeroDimsAllowed) {
-  auto oldType = cast<VectorType>(operand.getType());
-  assert(areLeadingDimsUnit(oldType, k) &&
-         "expected non-scalable leading unit dims to drop");
-  assert((zeroDimsAllowed || k < oldType.getRank()) &&
-         "target op does not allow 0-D vectors");
-  VectorType newType = VectorType::get(oldType.getShape().drop_front(k),
-                                       oldType.getElementType(),
-                                       oldType.getScalableDims().drop_front(k));
-  return b.createOrFold<vector::ShapeCastOp>(loc, newType, operand);
-}
-
-/// Returns the vector type obtained by applying `permutation` to `type`.
-static VectorType permuteVectorType(VectorType type,
-                                    ArrayRef<int64_t> permutation) {
-  assert(static_cast<int64_t>(permutation.size()) == type.getRank() &&
-         "expected a permutation matching the operand rank");
-  SmallVector<int64_t> permutedShape =
-      applyPermutation(type.getShape(), permutation);
-  SmallVector<bool> permutedScalableDims =
-      applyPermutation(type.getScalableDims(), permutation);
-  return VectorType::get(permutedShape, type.getElementType(),
-                         permutedScalableDims);
-}
-
-/// Like `dropLeadingUnitDims` except that if all dimensions would be dropped,
-/// the single element inside that vector is extracted and returned.
-static Value dropLeadingUnitDims0DIsScalar(OpBuilder &b, Location loc,
-                                           Value operand, int64_t k) {
-  auto oldType = cast<VectorType>(operand.getType());
-  assert(areLeadingDimsUnit(oldType, k) &&
-         "expected non-scalable leading unit dims to drop");
-
-  if (k == oldType.getRank()) {
-    SmallVector<int64_t> zeros(k, static_cast<int64_t>(0));
-    return vector::ExtractOp::create(b, loc, operand, zeros);
-  }
-
-  return dropLeadingUnitDims(b, loc, operand, k,
-                             /*zeroDimsAllowed=*/true);
-}
-
 namespace {
 
 // Casts away leading one dimensions in vector.extract_strided_slice's vector
-// input by inserting vector.shape_cast.
+// input by inserting vector.broadcast.
 struct CastAwayExtractStridedSliceLeadingOneDim
     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
   using Base::Base;
@@ -161,8 +63,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     // the same rank. Here we drop leading one dimensions from the input vector
     // type to make sure we don't cause mismatch.
     VectorType oldSrcType = extractOp.getSourceVectorType();
-    VectorType newSrcType =
-        trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false);
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
 
     if (newSrcType.getRank() == oldSrcType.getRank())
       return failure();
@@ -177,8 +78,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
 
     Location loc = extractOp.getLoc();
 
-    Value newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newSrcType, extractOp.getSource());
+    Value newSrcVector = vector::ExtractOp::create(
+        rewriter, loc, extractOp.getSource(), splatZero(dropCount));
 
     // The offsets/sizes/strides attribute can have a less number of elements
     // than the input vector's rank: it is meant for the leading dimensions.
@@ -193,7 +94,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
         rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
         newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                      newExtractOp);
 
     return success();
@@ -201,7 +102,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
 };
 
 // Casts away leading one dimensions in vector.insert_strided_slice's vector
-// inputs by inserting vector.shape_cast.
+// inputs by inserting vector.broadcast.
 struct CastAwayInsertStridedSliceLeadingOneDim
     : public OpRewritePattern<vector::InsertStridedSliceOp> {
   using Base::Base;
@@ -209,11 +110,9 @@ struct CastAwayInsertStridedSliceLeadingOneDim
   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                 PatternRewriter &rewriter) const override {
     VectorType oldSrcType = insertOp.getSourceVectorType();
-    VectorType newSrcType =
-        trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false);
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
     VectorType oldDstType = insertOp.getDestVectorType();
-    VectorType newDstType =
-        trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/false);
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
 
     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
@@ -223,10 +122,10 @@ struct CastAwayInsertStridedSliceLeadingOneDim
     // Trim leading one dimensions from both operands.
     Location loc = insertOp.getLoc();
 
-    Value newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newSrcType, insertOp.getValueToStore());
-    Value newDstVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newDstType, insertOp.getDest());
+    Value newSrcVector = vector::ExtractOp::create(
+        rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+    Value newDstVector = vector::ExtractOp::create(
+        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
 
     auto newOffsets = rewriter.getArrayAttr(
         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
@@ -237,7 +136,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
         rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
         newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);
 
     return success();
@@ -245,7 +144,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
 };
 
 // Casts away leading one dimensions in vector.insert's vector inputs by
-// inserting vector.shape_cast.
+// inserting vector.broadcast.
 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
   using Base::Base;
 
@@ -255,14 +154,13 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     Type newSrcType = oldSrcType;
     int64_t oldSrcRank = 0, newSrcRank = 0;
     if (auto type = dyn_cast<VectorType>(oldSrcType)) {
-      newSrcType = trimLeadingUnitDims(type, /*zeroDimsAllowed=*/false);
+      newSrcType = trimLeadingOneDims(type);
       oldSrcRank = type.getRank();
       newSrcRank = cast<VectorType>(newSrcType).getRank();
     }
 
     VectorType oldDstType = insertOp.getDestVectorType();
-    VectorType newDstType =
-        trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/oldSrcRank == 0);
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
 
     int64_t srcDropCount = oldSrcRank - newSrcRank;
     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
@@ -273,11 +171,12 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     Location loc = insertOp.getLoc();
 
     Value newSrcVector = insertOp.getValueToStore();
-    if (oldSrcRank != 0)
-      newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
-          loc, cast<VectorType>(newSrcType), insertOp.getValueToStore());
-    Value newDstVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newDstType, insertOp.getDest());
+    if (oldSrcRank != 0) {
+      newSrcVector = vector::ExtractOp::create(
+          rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+    }
+    Value newDstVector = vector::ExtractOp::create(
+        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
 
     // New position rank needs to be computed in two steps: (1) if destination
     // type has leading unit dims, we also trim the position array accordingly,
@@ -294,7 +193,7 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
                                                 newDstVector, newPosition);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);
 
     return success();
@@ -302,10 +201,20 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
 };
 
 static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
-                                  VectorType newType, AffineMap newMap) {
+                                  VectorType newType, AffineMap newMap,
+                                  VectorType oldMaskType) {
   // Infer the type of the new mask from the new map.
   VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
-  return b.createOrFold<vector::ShapeCastOp>(loc, newMaskType, mask);
+
+  // If the new mask is broadcastable to the old result type, we can safely
+  // use a `vector.extract` to get the new mask. Otherwise the best we can
+  // do is shape cast.
+  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+      BroadcastableToResult::Success) {
+    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+    return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
+  }
+  return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
 }
 
 // Turns vector.transfer_read on vector with leading 1 dimensions into
@@ -320,7 +229,7 @@ struct CastAwayTransferReadLeadingOneDim
     // TODO(#78787): Not supported masked op yet.
     if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
       return failure();
-    // Nothing to trim when the transfer itself has rank zero.
+    // TODO: support 0-d corner case.
     if (read.getTransferRank() == 0)
       return failure();
 
@@ -329,7 +238,7 @@ struct CastAwayTransferReadLeadingOneDim
       return failure();
 
     VectorType oldType = read.getVectorType();
-    VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true);
+    VectorType newType = trimLeadingOneDims(oldType);
 
     if (newType == oldType)
       return failure();
@@ -347,14 +256,16 @@ struct CastAwayTransferReadLeadingOneDim
           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
 
     Value mask = Value();
-    if (read.getMask())
+    if (read.getMask()) {
+      VectorType maskType = read.getMaskType();
       mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
-                                  newType, newMap);
+                                  newType, newMap, maskType);
+    }
 
     auto newRead = vector::TransferReadOp::create(
         rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
         AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
     return success();
   }
@@ -372,7 +283,7 @@ struct CastAwayTransferWriteLeadingOneDim
     // TODO(#78787): Not supported masked op yet.
     if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
       return failure();
-    // Nothing to trim when the transfer itself has rank zero.
+    // TODO: support 0-d corner case.
     if (write.getTransferRank() == 0)
       return failure();
 
@@ -381,9 +292,11 @@ struct CastAwayTransferWriteLeadingOneDim
       return failure();
 
     VectorType oldType = write.getVectorType();
-    VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true);
+    VectorType newType = trimLeadingOneDims(oldType);
     if (newType == oldType)
       return failure();
+    int64_t dropDim = oldType.getRank() - newType.getRank();
+
     AffineMap oldMap = write.getPermutationMap();
     ArrayRef<AffineExpr> newResults =
         oldMap.getResults().take_back(newType.getRank());
@@ -396,12 +309,13 @@ struct CastAwayTransferWriteLeadingOneDim
       inBoundsAttr = rewriter.getArrayAttr(
           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
 
-    auto newVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        write.getLoc(), newType, write.getVector());
+    auto newVector = vector::ExtractOp::create(
+        rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));
 
     if (write.getMask()) {
-      Value newMask = dropUnitDimsFromMask(rewriter, write.getLoc(),
-                                           write.getMask(), newType, newMap);
+      VectorType maskType = write.getMaskType();
+      Value newMask = dropUnitDimsFromMask(
+          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
           write, newVector, write.getBase(), write.getIndices(),
           AffineMapAttr::get(newMap), newMask, inBoundsAttr);
@@ -417,15 +331,6 @@ struct CastAwayTransferWriteLeadingOneDim
 
 } // namespace
 
-namespace {
-struct VectorContractOperandCastPlan {
-  AffineMap map;
-  SmallVector<int64_t> permutation;
-  bool dropLeadingUnitDim = false;
-  bool permuteOperand = false;
-};
-} // namespace
-
 FailureOr<Value>
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                                MaskingOpInterface maskingOp,
@@ -435,7 +340,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
     return failure();
   if (oldAccType.getRank() < 1)
     return failure();
-  if (!isNonScalableUnitDim(oldAccType, 0))
+  if (oldAccType.getShape()[0] != 1)
     return failure();
   // currently we support only dropping one dim but the pattern can be applied
   // greedily to drop more.
@@ -462,70 +367,74 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
 
   SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                  contractOp.getAcc()};
-  SmallVector<VectorContractOperandCastPlan> operandCastPlans;
   SmallVector<Value> newOperands;
   auto loc = contractOp.getLoc();
 
-  if (maskingOp) {
-    auto oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
-    if (oldMaskType.getRank() <= 1 || dimToDrop >= oldMaskType.getRank() ||
-        !isNonScalableUnitDim(oldMaskType, dimToDrop))
-      return failure();
-  }
-
   for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
     // Check if the dim to be dropped exists as a leading dim in the operand
-    // if it does then we use vector.shape_cast to drop it.
-    VectorContractOperandCastPlan plan;
+    // if it does then we use vector.extract to drop it.
+    bool validExtract = false;
     SmallVector<AffineExpr> results;
-    plan.map = it.value();
-    int64_t originalZeroDim = plan.map.getDimPosition(0);
-    if (originalZeroDim != dimToDrop) {
+    auto map = it.value();
+    int64_t orginalZeroDim = it.value().getDimPosition(0);
+    if (orginalZeroDim != dimToDrop) {
       // There are two reasons to be in this path, 1. We need to
-      // permute the operand type to make the dim to be dropped
+      // transpose the operand to make the dim to be dropped
       // leading. 2. The dim to be dropped does not exist and in
-      // that case we dont want to add a unit permutation but we must
+      // that case we dont want to add a unit transpose but we must
       // check all the indices to make sure this is the case.
-      SmallVector<AffineExpr> permutedResult...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/199546


More information about the Mlir-commits mailing list