[Mlir-commits] [mlir] [MLIR][Vector] Allow unrolling to drop leading unit dimensions (PR #164710)

Nishant Patel llvmlistbot at llvm.org
Wed Oct 22 13:58:46 PDT 2025


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/164710

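This PR relaxes the check in `getTargetShape` so that unrolling can also kick in when the target shape only reduces dimensionality. For example, it allows [1, 1, N] -> [N] to be unrolled; the result behaves much like a `vector.shape_cast` that drops the leading unit dimensions.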

>From 4707325ff13b83f58aed99d386ec9037691d9563 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 22 Oct 2025 20:34:57 +0000
Subject: [PATCH] Allow unrolling to drop leading unit dimensions

---
 .../Vector/Transforms/VectorUnroll.cpp        | 23 +++++++++++++++++--
 .../Dialect/Vector/vector-unroll-options.mlir | 15 ++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..09f10b3ac952d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -122,8 +122,27 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
     return std::nullopt;
   }
   if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
-    LDBG() << "--no unrolling needed -> SKIP";
-    return std::nullopt;
+    // If maybeShapeRatio are all 1s, only allow unrolling for leading unit
+    // dimension removal: [1,1,...,n] -> [n]
+    if (maybeUnrollShape->size() <= targetShape->size()) {
+      LDBG() << "--no dimension reduction -> SKIP";
+      return std::nullopt;
+    }
+
+    size_t dimDiff = maybeUnrollShape->size() - targetShape->size();
+    ArrayRef<int64_t> srcShape = *maybeUnrollShape;
+    ArrayRef<int64_t> tgtShape = *targetShape;
+
+    // Check leading dimensions are 1s and remaining matches target
+    bool isValidRemoval = llvm::all_of(srcShape.slice(0, dimDiff),
+                                       [](int64_t dim) { return dim == 1; }) &&
+                          srcShape.slice(dimDiff) == tgtShape;
+
+    if (!isValidRemoval) {
+      LDBG() << "--not a valid leading unit dimension removal -> SKIP";
+      return std::nullopt;
+    }
+    LDBG() << "--leading unit dimension removal -> CONTINUE";
   }
   LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
   return targetShape;
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..9fd77645b78b5 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,18 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
 // CHECK-NOT: arith.addf
 // CHECK: return
+
+
+func.func @elementwise_leading_unit_dim(%v1: vector<1x2x2xf32>, %v2: vector<1x2x2xf32>) -> vector<1x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<1x2x2xf32>
+  return %0 : vector<1x2x2xf32>
+}
+
+// CHECK-LABEL: func @elementwise_leading_unit_dim
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x2x2xf32>, %[[ARG1:.*]]: vector<1x2x2xf32>) -> vector<1x2x2xf32> {
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[S_RHS:.*]] = vector.shape_cast %[[ARG1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD:.*]] = arith.addf %[[S_LHS]], %[[S_RHS]] : vector<2x2xf32>
+//       CHECK:   %[[INS:.*]] = vector.insert_strided_slice %[[ADD]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<1x2x2xf32>
+//       CHECK:   return %[[INS]] : vector<1x2x2xf32>



More information about the Mlir-commits mailing list