[Mlir-commits] [mlir] 7cc27e2 - [MLIR][Vector] Enhance shape_cast unrolling support in case the target shape is [1, 1, ..1] (#183436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 27 07:11:32 PST 2026


Author: Jianhui Li
Date: 2026-02-27T07:11:27-08:00
New Revision: 7cc27e28db97a9b2fd145d35044b7dbe6c8426f2

URL: https://github.com/llvm/llvm-project/commit/7cc27e28db97a9b2fd145d35044b7dbe6c8426f2
DIFF: https://github.com/llvm/llvm-project/commit/7cc27e28db97a9b2fd145d35044b7dbe6c8426f2.diff

LOG: [MLIR][Vector] Enhance shape_cast unrolling support in case the target shape is [1, 1, ..1]  (#183436)

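This PR fixes a minor issue in shape_cast unrolling: when all target
dimensions are unit-sized (e.g. a target shape of [1, 1]), `isContiguous`
no longer strips every leading unit dimension; it now always keeps at
least one dimension of each shape for the contiguity comparison. It also
lets XeGPU blocking derive the tile shape of `vector.shape_cast` from its
result.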

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b62ce8a2ec398..37a691c4fce7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1196,14 +1196,14 @@ struct UnrollConstantMaskPattern
 static bool isContiguous(ArrayRef<int64_t> extractShape,
                          ArrayRef<int64_t> shape) {
 
-  if (extractShape.size() > shape.size())
+  if (extractShape.empty() || shape.empty() ||
+      extractShape.size() > shape.size())
     return false;
 
-  while (!extractShape.empty() && extractShape.front() == 1) {
+  while (extractShape.size() > 1 && extractShape.front() == 1)
     extractShape = extractShape.drop_front();
-  }
 
-  while (!shape.empty() && shape.front() == 1) {
+  while (shape.size() > 1 && shape.front() == 1) {
     shape = shape.drop_front();
   }
 

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 49b66d2a8f6f6..206f52a6c71cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -214,7 +214,8 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
     return getTileShape(op->getOpOperand(0));
 
   if (isa<vector::TransposeOp, vector::BroadcastOp, vector::StepOp,
-          vector::ConstantMaskOp, vector::CreateMaskOp>(op))
+          vector::ShapeCastOp, vector::ConstantMaskOp, vector::CreateMaskOp>(
+          op))
     return getTileShape(op->getOpResult(0));
 
   return std::nullopt;

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index c2e7f6a9338b1..14bc81a06c098 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -647,3 +647,21 @@ func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> {
 // CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32>
 // CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
 // CHECK:   return %[[I1]] : vector<1x32xf32>
+
+
+// TargetShape is [1x1]
+func.func @shape_cast_with_all_unit_target_shape(%v: vector<2xf32>) -> vector<2x1xf32> {
+  %0 = vector.shape_cast %v : vector<2xf32> to vector<2x1xf32>
+  return %0 : vector<2x1xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_with_all_unit_target_shape
+// CHECK-SAME: (%[[V:.*]]: vector<2xf32>) -> vector<2x1xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x1xf32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK:   %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<1xf32> to vector<1x1xf32>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x1xf32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<1xf32> to vector<1x1xf32>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x1xf32>
+// CHECK:   return %[[I1]] : vector<2x1xf32>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 3317ae8d11b0d..8005d1ff76d53 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -202,6 +202,10 @@ struct TestVectorUnrollingPatterns
                       resultShape[1] == 32) {
                     return SmallVector<int64_t>{1, 16};
                   }
+                  if (resultShape.size() == 2 && resultShape[0] == 2 &&
+                      resultShape[1] == 1) {
+                    return SmallVector<int64_t>{1, 1};
+                  }
                   // Default case: [2,4] for all tests.
                   return SmallVector<int64_t>{2, 4};
                 })


        


More information about the Mlir-commits mailing list