[Mlir-commits] [mlir] [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (PR #162515)

Nishant Patel llvmlistbot at llvm.org
Thu Oct 9 10:32:40 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/162515

>From f854b2dcf748a0ff9e5e39d5fae70e3cd437f774 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 8 Oct 2025 16:37:37 +0000
Subject: [PATCH 1/2] Extend elementwise pattern to support unrolling from
 higher rank to lower rank

---
 .../Vector/Transforms/VectorUnroll.cpp        | 52 +++++++++++-----
 .../Dialect/Vector/vector-unroll-options.mlir | 60 +++++++++++++++++--
 2 files changed, 92 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5f1cdd3..62d65e28e8c2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -468,23 +468,30 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    // Bail-out if rank(source) != rank(target). The main limitation here is the
-    // fact that `ExtractStridedSlice` requires the rank for the input and
-    // output to match. If needed, we can relax this later.
-    if (originalSize.size() != targetShape->size())
-      return rewriter.notifyMatchFailure(
-          op, "expected input vector rank to match target shape rank");
+
     Location loc = op->getLoc();
+
+    // Handle rank mismatch by adding leading unit dimensions to targetShape
+    SmallVector<int64_t> adjustedTargetShape = *targetShape;
+    SmallVector<int64_t> adjustedOffsets;
+    if (originalSize.size() > targetShape->size()) {
+      // Add leading unit dimensions to targetShape
+      int64_t rankDiff = originalSize.size() - targetShape->size();
+      adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
+    }
+
     // Prepare the result vector.
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    VectorType newVecType =
+    SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+    VectorType extractVecType =
+        VectorType::get(adjustedTargetShape, dstVecType.getElementType());
+    VectorType computeVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
     // Create the unrolled computation.
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalSize, *targetShape)) {
+         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
           extractOperands.push_back(operand.get());
           continue;
         }
-        extractOperands.push_back(
-            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-                loc, operand.get(), offsets, *targetShape, strides));
+        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+        // Reshape to remove leading unit dims if needed
+        if (adjustedTargetShape.size() > targetShape->size()) {
+          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+              loc, VectorType::get(*targetShape, vecType.getElementType()),
+              extracted);
+        }
+        extractOperands.push_back(extracted);
       }
+
       Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, newVecType);
+          rewriter, loc, op, extractOperands, computeVecType);
+
+      Value computeResult = newOp->getResult(0);
+
+      // Reshape back to higher rank if needed for insertion
+      if (adjustedTargetShape.size() > targetShape->size()) {
+        computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
+            loc, extractVecType, computeResult);
+      }
+
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, newOp->getResult(0), result, offsets, strides);
+          loc, computeResult, result, offsets, strides);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e0f7f1d..a26e4b0baa05b 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,40 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
   %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
   return %0 : vector<3x2x2xf32>
 }
-// CHECK-LABEL: func @negative_vector_fma_3d
-//   CHECK-NOT: vector.extract_strided_slice
-//       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+//       CHECK:   %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +465,26 @@ func.func @vector_step() -> vector<32xindex> {
 // CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
 // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
 // CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+  return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   return %[[I1]] : vector<2x2x2xf32>

>From 01357d9b7dc5f0707ab8220bb7eb561935f88d35 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 9 Oct 2025 17:26:21 +0000
Subject: [PATCH 2/2] Remove shape_cast before insert_strided_slice

---
 .../Dialect/Vector/Transforms/VectorUnroll.cpp    | 14 ++++++--------
 .../Dialect/Vector/vector-unroll-options.mlir     | 15 +++++----------
 2 files changed, 11 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 62d65e28e8c2e..dd01873ded05e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -484,8 +484,6 @@ struct UnrollElementwisePattern : public RewritePattern {
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
     SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
-    VectorType extractVecType =
-        VectorType::get(adjustedTargetShape, dstVecType.getElementType());
     VectorType computeVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
@@ -516,14 +514,14 @@ struct UnrollElementwisePattern : public RewritePattern {
 
       Value computeResult = newOp->getResult(0);
 
-      // Reshape back to higher rank if needed for insertion
-      if (adjustedTargetShape.size() > targetShape->size()) {
-        computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
-            loc, extractVecType, computeResult);
-      }
+      // Use strides sized to targetShape for proper insertion
+      SmallVector<int64_t> insertStrides =
+          (adjustedTargetShape.size() > targetShape->size())
+              ? SmallVector<int64_t>(targetShape->size(), 1)
+              : strides;
 
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, computeResult, result, offsets, strides);
+          loc, computeResult, result, offsets, insertStrides);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index a26e4b0baa05b..feea7c0f74531 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -201,8 +201,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 //       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
-//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
 //       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
@@ -210,8 +209,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 //       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
-//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
 //       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
@@ -219,8 +217,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 //       CHECK:   %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
-//       CHECK:   %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
 //       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
@@ -478,13 +475,11 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
 //       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
-//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
 //       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
-//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
 //       CHECK:   return %[[I1]] : vector<2x2x2xf32>



More information about the Mlir-commits mailing list