[Mlir-commits] [mlir] [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (PR #162515)

Nishant Patel llvmlistbot at llvm.org
Tue Oct 14 09:03:21 PDT 2025


================
@@ -465,41 +465,65 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto targetShape = getTargetShape(options, op);
     if (!targetShape)
       return failure();
+    int64_t targetShapeRank = targetShape->size();
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    // Bail-out if rank(source) != rank(target). The main limitation here is the
-    // fact that `ExtractStridedSlice` requires the rank for the input and
-    // output to match. If needed, we can relax this later.
-    if (originalSize.size() != targetShape->size())
-      return rewriter.notifyMatchFailure(
-          op, "expected input vector rank to match target shape rank");
+    int64_t originalShapeRank = originalSize.size();
+
     Location loc = op->getLoc();
+
+    // Handle rank mismatch by adding leading unit dimensions to targetShape
+    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+    int64_t rankDiff = originalShapeRank - targetShapeRank;
+    std::fill(adjustedTargetShape.begin(),
+              adjustedTargetShape.begin() + rankDiff, 1);
+    std::copy(targetShape->begin(), targetShape->end(),
+              adjustedTargetShape.begin() + rankDiff);
----------------
nbpatel wrote:

Added a 4D-to-2D test case.
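
For illustration, the rank adjustment in the hunk above can be sketched outside MLIR with plain `std::vector`. This is only a sketch: the helper name `padTargetShape`, the assertion, and the example shapes are assumptions for this note, not part of the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not in the patch): pad `targetShape` with leading unit
// dimensions so that its rank matches `originalRank`, mirroring the
// std::fill/std::copy logic in the hunk.
std::vector<int64_t> padTargetShape(const std::vector<int64_t> &targetShape,
                                    int64_t originalRank) {
  int64_t rankDiff = originalRank - static_cast<int64_t>(targetShape.size());
  // Assumption for this sketch: the target rank never exceeds the original
  // rank.
  assert(rankDiff >= 0 && "target rank must not exceed original rank");

  // Start with all ones, then overwrite the trailing dimensions with the
  // requested target shape.
  std::vector<int64_t> adjusted(originalRank, 1);
  std::copy(targetShape.begin(), targetShape.end(),
            adjusted.begin() + rankDiff);
  return adjusted;
}
```

With a rank-4 original shape and a target shape of `{2, 4}`, `padTargetShape({2, 4}, 4)` returns `{1, 1, 2, 4}`; when the ranks already match, the target shape is returned unchanged.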

https://github.com/llvm/llvm-project/pull/162515

