[Mlir-commits] [mlir] 7cc27e2 - [MLIR][Vector] Enhance shape_cast unrolling support in case the target shape is [1, 1, ..1] (#183436)
llvmlistbot at llvm.org
Fri Feb 27 07:11:32 PST 2026
Author: Jianhui Li
Date: 2026-02-27T07:11:27-08:00
New Revision: 7cc27e28db97a9b2fd145d35044b7dbe6c8426f2
URL: https://github.com/llvm/llvm-project/commit/7cc27e28db97a9b2fd145d35044b7dbe6c8426f2
DIFF: https://github.com/llvm/llvm-project/commit/7cc27e28db97a9b2fd145d35044b7dbe6c8426f2.diff
LOG: [MLIR][Vector] Enhance shape_cast unrolling support in case the target shape is [1, 1, ..1] (#183436)
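This PR fixes a minor issue in shape_cast unrolling: when all target
dimensions are unit-sized (e.g. [1, 1]), the contiguity check no longer
strips every leading unit dimension. It now keeps at least one dimension,
so an all-ones shape is never reduced to an empty shape, and empty shapes
are rejected up front. It also makes XeGPU blocking take the tile shape of
a vector.shape_cast from its result, as is already done for transpose,
broadcast, step, and mask ops.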
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
mlir/test/Dialect/Vector/vector-unroll-options.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b62ce8a2ec398..37a691c4fce7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1196,14 +1196,14 @@ struct UnrollConstantMaskPattern
static bool isContiguous(ArrayRef<int64_t> extractShape,
ArrayRef<int64_t> shape) {
- if (extractShape.size() > shape.size())
+ if (extractShape.empty() || shape.empty() ||
+ extractShape.size() > shape.size())
return false;
- while (!extractShape.empty() && extractShape.front() == 1) {
+ while (extractShape.size() > 1 && extractShape.front() == 1)
extractShape = extractShape.drop_front();
- }
- while (!shape.empty() && shape.front() == 1) {
+ while (shape.size() > 1 && shape.front() == 1) {
shape = shape.drop_front();
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 49b66d2a8f6f6..206f52a6c71cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -214,7 +214,8 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
return getTileShape(op->getOpOperand(0));
if (isa<vector::TransposeOp, vector::BroadcastOp, vector::StepOp,
- vector::ConstantMaskOp, vector::CreateMaskOp>(op))
+ vector::ShapeCastOp, vector::ConstantMaskOp, vector::CreateMaskOp>(
+ op))
return getTileShape(op->getOpResult(0));
return std::nullopt;
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index c2e7f6a9338b1..14bc81a06c098 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -647,3 +647,21 @@ func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> {
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
// CHECK: return %[[I1]] : vector<1x32xf32>
+
+
+// TargetShape is [1x1]
+func.func @shape_cast_with_all_unit_target_shape(%v: vector<2xf32>) -> vector<2x1xf32> {
+ %0 = vector.shape_cast %v : vector<2xf32> to vector<2x1xf32>
+ return %0 : vector<2x1xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_with_all_unit_target_shape
+// CHECK-SAME: (%[[V:.*]]: vector<2xf32>) -> vector<2x1xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x1xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x1xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x1xf32>
+// CHECK: return %[[I1]] : vector<2x1xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 3317ae8d11b0d..8005d1ff76d53 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -202,6 +202,10 @@ struct TestVectorUnrollingPatterns
resultShape[1] == 32) {
return SmallVector<int64_t>{1, 16};
}
+ if (resultShape.size() == 2 && resultShape[0] == 2 &&
+ resultShape[1] == 1) {
+ return SmallVector<int64_t>{1, 1};
+ }
// Default case: [2,4] for all tests.
return SmallVector<int64_t>{2, 4};
})
More information about the Mlir-commits
mailing list