[Mlir-commits] [mlir] b4c31dc - [mlir][Vector] add vector.insert canonicalization pattern to convert a chain of insertions to vector.from_elements (#142944)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 19 05:43:34 PDT 2025


Author: Yang Bai
Date: 2025-08-19T13:43:31+01:00
New Revision: b4c31dc98dfc929728904cd96f0f4cf812c4d5b5

URL: https://github.com/llvm/llvm-project/commit/b4c31dc98dfc929728904cd96f0f4cf812c4d5b5
DIFF: https://github.com/llvm/llvm-project/commit/b4c31dc98dfc929728904cd96f0f4cf812c4d5b5.diff

LOG: [mlir][Vector] add vector.insert canonicalization pattern to convert a chain of insertions to vector.from_elements (#142944)

## Description

This change introduces a new canonicalization pattern for the MLIR
Vector dialect that optimizes chains of insertions. The optimization
identifies when a vector is **completely** initialized through a series
of vector.insert operations and replaces the entire chain with a
single `vector.from_elements` operation.

Please be aware that the new pattern **doesn't** work for poison vectors
where only **some** elements are set, as MLIR doesn't support partial
poison vectors for now.

**New Pattern: InsertChainFullyInitialized**

* Detects chains of vector.insert operations.
* Validates that all insertions are at static positions, and all
intermediate insertions have only one use.
* Ensures the entire vector is **completely** initialized.
* Replaces the entire chain with a
single vector.from_elements operation.

**Refactored Helper Function**

* Extracted `calculateInsertPosition` from
`foldDenseElementsAttrDestInsertOp` to avoid code duplication.

## Example

```
// Before:
%v1 = vector.insert %c10, %v0[0] : i64 into vector<2xi64>
%v2 = vector.insert %c20, %v1[1] : i64 into vector<2xi64>

// After:
%v2 = vector.from_elements %c10, %c20 : vector<2xi64>
```

It also works for multidimensional vectors.

```
// Before:
%v1 = vector.insert %cv0, %v0[0] : vector<3xi64> into vector<2x3xi64>
%v2 = vector.insert %cv1, %v1[1] : vector<3xi64> into vector<2x3xi64>

// After:
%0:3 = vector.to_elements %cv0 : vector<3xi64>
%1:3 = vector.to_elements %cv1 : vector<3xi64>
%v2 = vector.from_elements %0#0, %0#1, %0#2, %1#0, %1#1, %1#2 : vector<2x3xi64>
```

---------

Co-authored-by: Yang Bai <yangb at nvidia.com>
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/vector-gather-lowering.mlir
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 74e48b59b6460..2b2581d353673 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3286,6 +3286,18 @@ LogicalResult InsertOp::verify() {
   return success();
 }
 
+// Calculate the linearized position at which the contiguous chunk of elements
+// is inserted: `positions` is padded with trailing zeros up to the rank of
+// `destTy` and then linearized using the strides of the shape of `destTy`.
+static int64_t calculateInsertPosition(VectorType destTy,
+                                       ArrayRef<int64_t> positions) {
+  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+  assert(positions.size() <= completePositions.size() &&
+         "positions size must be less than or equal to destTy rank");
+  copy(positions, completePositions.begin());
+  return linearize(completePositions, computeStrides(destTy.getShape()));
+}
+
 namespace {
 
 // If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3323,6 +3335,132 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
     return success();
   }
 };
+
+/// Pattern to optimize a chain of insertions.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Only insert values at static positions.
+/// 2. Completely initialize all elements in the resulting vector.
+/// 3. All intermediate insert operations have only one use.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single vector.from_elements operation.
+///
+/// To keep this pattern simple, and avoid spending too much time on matching
+/// fragmented insert chains, this pattern only considers the last insert op in
+/// the chain.
+///
+/// Example transformation:
+///   %poison = ub.poison : vector<2xi32>
+///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+///   %result = vector.from_elements %c1, %c2 : vector<2xi32>
+class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorType destTy = op.getDestVectorType();
+    if (destTy.isScalable())
+      return failure();
+    // Only match the last insert of a chain, i.e. no insert uses it as dest.
+    for (Operation *user : op.getResult().getUsers())
+      if (auto insertOp = dyn_cast<InsertOp>(user))
+        if (insertOp.getDest() == op.getResult())
+          return failure();
+
+    InsertOp currentOp = op;
+    SmallVector<InsertOp> chainInsertOps;
+    while (currentOp) {
+      // Check cond 1: Dynamic position is not supported.
+      if (currentOp.hasDynamicPosition())
+        return failure();
+
+      chainInsertOps.push_back(currentOp);
+      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
+      // Check cond 3: Intermediate inserts have only one use to avoid an
+      // explosion of vectors.
+      if (currentOp && !currentOp->hasOneUse())
+        return failure();
+    }
+
+    int64_t vectorSize = destTy.getNumElements();
+    int64_t initializedCount = 0;
+    SmallVector<bool> initializedDestIdxs(vectorSize, false);
+    SmallVector<int64_t> pendingInsertPos;
+    SmallVector<int64_t> pendingInsertSize;
+    SmallVector<Value> pendingInsertValues;
+
+    for (auto insertOp : chainInsertOps) {
+      // Inserts at a poison index are not handled by this pattern.
+      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
+        return failure();
+
+      // Calculate the linearized position for inserting elements.
+      int64_t insertBeginPosition =
+          calculateInsertPosition(destTy, insertOp.getStaticPosition());
+
+      // The valueToStore operand may be a scalar (one element) or a vector
+      // (all of its elements are inserted).
+      int64_t insertSize = 1;
+      if (auto srcVectorType =
+              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
+        insertSize = srcVectorType.getNumElements();
+
+      assert(insertBeginPosition + insertSize <= vectorSize &&
+             "insert would overflow the vector");
+
+      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
+                                           insertBeginPosition + insertSize)) {
+        if (initializedDestIdxs[index])
+          continue;
+        initializedDestIdxs[index] = true;
+        ++initializedCount;
+      }
+
+      // Defer the creation of ops until we are sure that the pattern will
+      // succeed.
+      pendingInsertPos.push_back(insertBeginPosition);
+      pendingInsertSize.push_back(insertSize);
+      pendingInsertValues.push_back(insertOp.getValueToStore());
+
+      if (initializedCount == vectorSize)
+        break;
+    }
+
+    // Check cond 2: all positions must be initialized.
+    if (initializedCount != vectorSize)
+      return failure();
+
+    SmallVector<Value> elements(vectorSize);
+    for (auto [insertBeginPosition, insertSize, valueToStore] :
+         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
+                                 pendingInsertValues))) {
+      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
+
+      if (!srcVectorType) {
+        elements[insertBeginPosition] = valueToStore;
+        continue;
+      }
+
+      SmallVector<Type> elementToInsertTypes(insertSize,
+                                             srcVectorType.getElementType());
+      // Get all elements from the vector in row-major order.
+      auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
+          op.getLoc(), elementToInsertTypes, valueToStore);
+      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
+        elements[insertBeginPosition + linearIdx] =
+            elementsToInsert.getResult(linearIdx);
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
+    return success();
+  }
+};
+
 } // namespace
 
 static Attribute
@@ -3349,13 +3487,9 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
       !insertOp->hasOneUse())
     return {};
 
-  // Calculate the linearized position of the continuous chunk of elements to
-  // insert.
-  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-  copy(insertOp.getStaticPosition(), completePositions.begin());
+  // Calculate the linearized position for inserting elements.
   int64_t insertBeginPosition =
-      linearize(completePositions, computeStrides(destTy.getShape()));
-
+      calculateInsertPosition(destTy, insertOp.getStaticPosition());
   SmallVector<Attribute> insertedValues;
   Type destEltType = destTy.getElementType();
 
@@ -3391,7 +3525,8 @@ static Value foldInsertUseChain(InsertOp insertOp) {
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+              InsertChainFullyInitialized>(context);
 }
 
 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index d68ba44ee8840..c85f4334ff2e5 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -83,20 +83,16 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
 // CHECK-LABEL: @transpose
 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
 func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
-  // CHECK: %[[UB:.*]] = ub.poison : vector<2xi32>
   // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[UB]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
+  // CHECK: %[[FROM_ELEMENTS0:.*]] = vector.from_elements %[[EXTRACT0]], %[[EXTRACT1]] : vector<2xi32>
   // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[UB]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
+  // CHECK: %[[FROM_ELEMENTS1:.*]] = vector.from_elements %[[EXTRACT2]], %[[EXTRACT3]] : vector<2xi32>
   // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[UB]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
-  // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
+  // CHECK: %[[FROM_ELEMENTS2:.*]] = vector.from_elements %[[EXTRACT4]], %[[EXTRACT5]] : vector<2xi32>
+  // CHECK: return %[[FROM_ELEMENTS0]], %[[FROM_ELEMENTS1]], %[[FROM_ELEMENTS2]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
   %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
   return %0 : vector<3x2xi32>
 }

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 08354dbf280c1..26b54566cb2cd 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -79,21 +79,17 @@ func.func @absf_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-LABEL:   func @absf_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @fabsf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @fabsf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @fabs(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @fabs(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @absf_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.absf %float : vector<2xf32>
@@ -116,21 +112,17 @@ func.func @acos_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-LABEL:   func @acos_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @acosf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @acosf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @acos(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @acos(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @acos_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.acos %float : vector<2xf32>
@@ -153,21 +145,17 @@ func.func @acosh_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-LABEL:   func @acosh_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @acoshf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @acoshf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @acosh(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @acosh(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @acosh_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.acosh %float : vector<2xf32>
@@ -190,21 +178,17 @@ func.func @asin_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-LABEL:   func @asin_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @asinf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @asinf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @asin(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @asin(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @asin_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.asin %float : vector<2xf32>
@@ -227,21 +211,17 @@ func.func @asinh_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-LABEL:   func @asinh_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @asinhf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @asinhf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @asinh(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @asinh(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @asinh_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.asinh %float : vector<2xf32>
@@ -274,21 +254,17 @@ func.func @atan_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) ->
 // CHECK-LABEL:   func @atan_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @atanf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @atanf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @atan(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @atan(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @atan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.atan %float : vector<2xf32>
@@ -321,21 +297,17 @@ func.func @atanh_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) ->
 // CHECK-LABEL:   func @atanh_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @atanhf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @atanhf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @atanh(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @atanh(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @atanh_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.atanh %float : vector<2xf32>
@@ -419,23 +391,19 @@ func.func @erf_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
 func.func @erf_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-  // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-  // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
   // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
   // CHECK:           %[[OUT0_F32:.*]] = call @erff(%[[IN0_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
   // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
   // CHECK:           %[[OUT1_F32:.*]] = call @erff(%[[IN1_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+  // CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
   %float_result = math.erf %float : vector<2xf32>
   // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
   // CHECK:           %[[OUT0_F64:.*]] = call @erf(%[[IN0_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
   // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
   // CHECK:           %[[OUT1_F64:.*]] = call @erf(%[[IN1_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+  // CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
   %double_result = math.erf %double : vector<2xf64>
-  // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+  // CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
 
@@ -459,21 +427,17 @@ func.func @exp_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
 // CHECK-LABEL:   func @exp_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @expf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @expf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @exp(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @exp(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @exp2_caller
@@ -496,21 +460,17 @@ func.func @exp2_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
 // CHECK-LABEL:   func @exp2_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @exp2f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @exp2f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @exp2(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @exp2(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @log_caller
@@ -533,21 +493,17 @@ func.func @log_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
 // CHECK-LABEL:   func @log_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @logf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @logf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @log(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @log(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @log2_caller
@@ -570,21 +526,17 @@ func.func @log2_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
 // CHECK-LABEL:   func @log2_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @log2f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @log2f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @log2(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @log2(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @log10_caller
@@ -607,21 +559,17 @@ func.func @log10_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
 // CHECK-LABEL:   func @log10_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @log10f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @log10f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @log10(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @log10(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @expm1_caller
@@ -644,21 +592,17 @@ func.func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
 // CHECK-LABEL:   func @expm1_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 func.func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32>) {
@@ -667,20 +611,16 @@ func.func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32
 }
 // CHECK-LABEL:   func @expm1_multidim_vec_caller(
 // CHECK-SAME:                           %[[VAL:.*]]: vector<2x2xf32>
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
 // CHECK:           %[[IN0_0_F32:.*]] = vector.extract %[[VAL]][0, 0] : f32 from vector<2x2xf32>
 // CHECK:           %[[OUT0_0_F32:.*]] = call @expm1f(%[[IN0_0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_1:.*]] = vector.insert %[[OUT0_0_F32]], %[[CVF]] [0, 0] : f32 into vector<2x2xf32>
 // CHECK:           %[[IN0_1_F32:.*]] = vector.extract %[[VAL]][0, 1] : f32 from vector<2x2xf32>
 // CHECK:           %[[OUT0_1_F32:.*]] = call @expm1f(%[[IN0_1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_2:.*]] = vector.insert %[[OUT0_1_F32]], %[[VAL_1]] [0, 1] : f32 into vector<2x2xf32>
 // CHECK:           %[[IN1_0_F32:.*]] = vector.extract %[[VAL]][1, 0] : f32 from vector<2x2xf32>
 // CHECK:           %[[OUT1_0_F32:.*]] = call @expm1f(%[[IN1_0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_3:.*]] = vector.insert %[[OUT1_0_F32]], %[[VAL_2]] [1, 0] : f32 into vector<2x2xf32>
 // CHECK:           %[[IN1_1_F32:.*]] = vector.extract %[[VAL]][1, 1] : f32 from vector<2x2xf32>
 // CHECK:           %[[OUT1_1_F32:.*]] = call @expm1f(%[[IN1_1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
-// CHECK:           return %[[VAL_4]] : vector<2x2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_0_F32]], %[[OUT0_1_F32]], %[[OUT1_0_F32]], %[[OUT1_1_F32]] : vector<2x2xf32>
+// CHECK:           return %[[RES_F32]] : vector<2x2xf32>
 // CHECK:         }
 
 // CHECK-LABEL: func @fma_caller(
@@ -704,29 +644,25 @@ func.func @fma_vec_caller(%float_a: vector<2xf32>, %float_b: vector<2xf32>, %flo
 // CHECK-SAME:                           %[[VAL_0A:.*]]: vector<2xf32>, %[[VAL_0B:.*]]: vector<2xf32>, %[[VAL_0C:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1A:.*]]: vector<2xf64>, %[[VAL_1B:.*]]: vector<2xf64>, %[[VAL_1C:.*]]: vector<2xf64>
 // CHECK-SAME:                           ) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32A:.*]] = vector.extract %[[VAL_0A]][0] : f32 from vector<2xf32>
 // CHECK:           %[[IN0_F32B:.*]] = vector.extract %[[VAL_0B]][0] : f32 from vector<2xf32>
 // CHECK:           %[[IN0_F32C:.*]] = vector.extract %[[VAL_0C]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @fmaf(%[[IN0_F32A]], %[[IN0_F32B]], %[[IN0_F32C]]) : (f32, f32, f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32A:.*]] = vector.extract %[[VAL_0A]][1] : f32 from vector<2xf32>
 // CHECK:           %[[IN1_F32B:.*]] = vector.extract %[[VAL_0B]][1] : f32 from vector<2xf32>
 // CHECK:           %[[IN1_F32C:.*]] = vector.extract %[[VAL_0C]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @fmaf(%[[IN1_F32A]], %[[IN1_F32B]], %[[IN1_F32C]]) : (f32, f32, f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64A:.*]] = vector.extract %[[VAL_1A]][0] : f64 from vector<2xf64>
 // CHECK:           %[[IN0_F64B:.*]] = vector.extract %[[VAL_1B]][0] : f64 from vector<2xf64>
 // CHECK:           %[[IN0_F64C:.*]] = vector.extract %[[VAL_1C]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @fma(%[[IN0_F64A]], %[[IN0_F64B]], %[[IN0_F64C]]) : (f64, f64, f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64A:.*]] = vector.extract %[[VAL_1A]][1] : f64 from vector<2xf64>
 // CHECK:           %[[IN1_F64B:.*]] = vector.extract %[[VAL_1B]][1] : f64 from vector<2xf64>
 // CHECK:           %[[IN1_F64C:.*]] = vector.extract %[[VAL_1C]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @fma(%[[IN1_F64A]], %[[IN1_F64B]], %[[IN1_F64C]]) : (f64, f64, f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @round_caller
@@ -814,23 +750,19 @@ func.func @sin_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
 func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-  // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-  // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
   // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
   // CHECK:           %[[OUT0_F32:.*]] = call @roundf(%[[IN0_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
   // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
   // CHECK:           %[[OUT1_F32:.*]] = call @roundf(%[[IN1_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+  // CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
   %float_result = math.round %float : vector<2xf32>
   // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
   // CHECK:           %[[OUT0_F64:.*]] = call @round(%[[IN0_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
   // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
   // CHECK:           %[[OUT1_F64:.*]] = call @round(%[[IN1_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+  // CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
   %double_result = math.round %double : vector<2xf64>
-  // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+  // CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
 
@@ -838,23 +770,19 @@ func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
 func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-  // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-  // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
   // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
   // CHECK:           %[[OUT0_F32:.*]] = call @roundevenf(%[[IN0_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
   // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
   // CHECK:           %[[OUT1_F32:.*]] = call @roundevenf(%[[IN1_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+  // CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
   %float_result = math.roundeven %float : vector<2xf32>
   // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
   // CHECK:           %[[OUT0_F64:.*]] = call @roundeven(%[[IN0_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
   // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
   // CHECK:           %[[OUT1_F64:.*]] = call @roundeven(%[[IN1_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+  // CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
   %double_result = math.roundeven %double : vector<2xf64>
-  // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+  // CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
 
@@ -862,23 +790,19 @@ func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
 func.func @trunc_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-  // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-  // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
   // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
   // CHECK:           %[[OUT0_F32:.*]] = call @truncf(%[[IN0_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
   // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
   // CHECK:           %[[OUT1_F32:.*]] = call @truncf(%[[IN1_F32]]) : (f32) -> f32
-  // CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+  // CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
   %float_result = math.trunc %float : vector<2xf32>
   // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
   // CHECK:           %[[OUT0_F64:.*]] = call @trunc(%[[IN0_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
   // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
   // CHECK:           %[[OUT1_F64:.*]] = call @trunc(%[[IN1_F64]]) : (f64) -> f64
-  // CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+  // CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
   %double_result = math.trunc %double : vector<2xf64>
-  // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+  // CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
 
@@ -907,21 +831,17 @@ func.func @tan_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) -> (
 // CHECK-LABEL:   func @tan_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @tanf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @tanf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @tan(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @tan(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 func.func @tan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
   %float_result = math.tan %float : vector<2xf32>
@@ -985,21 +905,17 @@ func.func @sqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
 // CHECK-LABEL:   func @sqrt_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @sqrtf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @sqrtf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @sqrt(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @sqrt(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @rsqrt_caller
@@ -1022,21 +938,17 @@ func.func @rsqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
 // CHECK-LABEL:   func @rsqrt_vec_caller(
 // CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @rsqrtf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @rsqrtf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @rsqrt(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @rsqrt(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
 // CHECK-LABEL: func @powf_caller(
@@ -1060,23 +972,19 @@ func.func @powf_vec_caller(%float_a: vector<2xf32>, %float_b: vector<2xf32>, %do
 // CHECK-SAME:                           %[[VAL_0A:.*]]: vector<2xf32>, %[[VAL_0B:.*]]: vector<2xf32>,
 // CHECK-SAME:                           %[[VAL_1A:.*]]: vector<2xf64>, %[[VAL_1B:.*]]: vector<2xf64>
 // CHECK-SAME:                           ) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
 // CHECK:           %[[IN0_F32A:.*]] = vector.extract %[[VAL_0A]][0] : f32 from vector<2xf32>
 // CHECK:           %[[IN0_F32B:.*]] = vector.extract %[[VAL_0B]][0] : f32 from vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @powf(%[[IN0_F32A]], %[[IN0_F32B]]) : (f32, f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
 // CHECK:           %[[IN1_F32A:.*]] = vector.extract %[[VAL_0A]][1] : f32 from vector<2xf32>
 // CHECK:           %[[IN1_F32B:.*]] = vector.extract %[[VAL_0B]][1] : f32 from vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @powf(%[[IN1_F32A]], %[[IN1_F32B]]) : (f32, f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
 // CHECK:           %[[IN0_F64A:.*]] = vector.extract %[[VAL_1A]][0] : f64 from vector<2xf64>
 // CHECK:           %[[IN0_F64B:.*]] = vector.extract %[[VAL_1B]][0] : f64 from vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @pow(%[[IN0_F64A]], %[[IN0_F64B]]) : (f64, f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
 // CHECK:           %[[IN1_F64A:.*]] = vector.extract %[[VAL_1A]][1] : f64 from vector<2xf64>
 // CHECK:           %[[IN1_F64B:.*]] = vector.extract %[[VAL_1B]][1] : f64 from vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @pow(%[[IN1_F64A]], %[[IN1_F64B]]) : (f64, f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:           %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK:           return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 4a7176e1f8d7d..c640ddea7507b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -625,40 +625,40 @@ func.func @insert_extract_transpose_2d(
 // -----
 
 // CHECK-LABEL: insert_extract_chain
-//  CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32>
+//  CHECK-SAME: %[[V334:[a-zA-Z0-9]*]]: vector<3x3x4xf32>
 //  CHECK-SAME: %[[V34:[a-zA-Z0-9]*]]: vector<3x4xf32>
 //  CHECK-SAME: %[[V4:[a-zA-Z0-9]*]]: vector<4xf32>
-func.func @insert_extract_chain(%v234: vector<2x3x4xf32>, %v34: vector<3x4xf32>, %v4: vector<4xf32>)
+func.func @insert_extract_chain(%v334: vector<3x3x4xf32>, %v34: vector<3x4xf32>, %v4: vector<4xf32>)
     -> (vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32>) {
   // CHECK-NEXT: %[[A34:.*]] = vector.insert
-  %A34 = vector.insert %v34, %v234[0]: vector<3x4xf32> into vector<2x3x4xf32>
+  %A34 = vector.insert %v34, %v334[0]: vector<3x4xf32> into vector<3x3x4xf32>
   // CHECK-NEXT: %[[B34:.*]] = vector.insert
-  %B34 = vector.insert %v34, %A34[1]: vector<3x4xf32> into vector<2x3x4xf32>
+  %B34 = vector.insert %v34, %A34[1]: vector<3x4xf32> into vector<3x3x4xf32>
   // CHECK-NEXT: %[[A4:.*]] = vector.insert
-  %A4 = vector.insert %v4, %B34[1, 0]: vector<4xf32> into vector<2x3x4xf32>
+  %A4 = vector.insert %v4, %B34[1, 0]: vector<4xf32> into vector<3x3x4xf32>
   // CHECK-NEXT: %[[B4:.*]] = vector.insert
-  %B4 = vector.insert %v4, %A4[1, 1]: vector<4xf32> into vector<2x3x4xf32>
+  %B4 = vector.insert %v4, %A4[1, 1]: vector<4xf32> into vector<3x3x4xf32>
 
   // Case 2.a. [1, 1] == insertpos ([1, 1])
   // Match %A4 insertionpos and fold to its source(i.e. %V4).
-   %r0 = vector.extract %B4[1, 1]: vector<4xf32> from vector<2x3x4xf32>
+   %r0 = vector.extract %B4[1, 1]: vector<4xf32> from vector<3x3x4xf32>
 
   // Case 3.a. insertpos ([1]) is a prefix of [1, 0].
   // Traverse %B34 to its source(i.e. %V34@[*0*]).
   // CHECK-NEXT: %[[R1:.*]] = vector.extract %[[V34]][0]
-   %r1 = vector.extract %B34[1, 0]: vector<4xf32> from vector<2x3x4xf32>
+   %r1 = vector.extract %B34[1, 0]: vector<4xf32> from vector<3x3x4xf32>
 
   // Case 4. [1] is a prefix of insertpos ([1, 1]).
   // Cannot traverse %B4.
   // CHECK-NEXT: %[[R2:.*]] = vector.extract %[[B4]][1]
-   %r2 = vector.extract %B4[1]: vector<3x4xf32> from vector<2x3x4xf32>
+   %r2 = vector.extract %B4[1]: vector<3x4xf32> from vector<3x3x4xf32>
 
   // Case 5. [0] is disjoint from insertpos ([1, 1]).
   // Traverse %B4 to its dest(i.e. %A4@[0]).
   // Traverse %A4 to its dest(i.e. %B34@[0]).
   // Traverse %B34 to its dest(i.e. %A34@[0]).
   // Match %A34 insertionpos and fold to its source(i.e. %V34).
-   %r3 = vector.extract %B4[0]: vector<3x4xf32> from vector<2x3x4xf32>
+   %r3 = vector.extract %B4[0]: vector<3x4xf32> from vector<3x3x4xf32>
 
   // CHECK: return %[[V4]], %[[R1]], %[[R2]], %[[V34]]
   return %r0, %r1, %r2, %r3:
@@ -1057,8 +1057,8 @@ func.func @insert_fold_same_rank(%v: vector<2x2xf32>) -> vector<2x2xf32> {
 
 // CHECK-LABEL: func @insert_no_fold_scalar_to_0d(
 //  CHECK-SAME:     %[[v:.*]]: vector<f32>)
-//       CHECK:   %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
-//       CHECK:   return %[[extract]]
+//       CHECK:   %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+//       CHECK:   return %[[cst]]
 func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = vector.insert %cst, %v [] : f32 into vector<f32>
@@ -2669,6 +2669,112 @@ func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3
 
 // -----
 
+// +---------------------------------------------------------------------------
+// Tests for InsertChainFullyInitialized.
+// +---------------------------------------------------------------------------
+// This pattern should fire when all vector elements are overwritten by vector.insert
+// at static positions, replacing the insert chain with vector.from_elements.
+// CHECK-LABEL: func.func @fully_insert_scalar_to_vector(
+//  CHECK-SAME: %[[DEST:.+]]: vector<2xi64>, %[[SRC1:.+]]: i64, %[[SRC2:.+]]: i64)
+//       CHECK: %[[RES:.+]] = vector.from_elements %[[SRC1]], %[[SRC2]] : vector<2xi64>
+//  CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_scalar_to_vector(%dest : vector<2xi64>, %src1 : i64, %src2 : i64) -> vector<2xi64> {
+  %v1 = vector.insert %src1, %dest[0] : i64 into vector<2xi64>
+  %v2 = vector.insert %src2, %v1[1] : i64 into vector<2xi64>
+  return %v2 : vector<2xi64>
+}
+
+// -----
+
+// Same as the above test, but with vector insertions.
+// CHECK-LABEL: func.func @fully_insert_vector_to_vector(
+//  CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: vector<2xi64>, %[[SRC2:.+]]: vector<2xi64>)
+//       CHECK: %[[ELE1:.+]]:2 = vector.to_elements %[[SRC1]] : vector<2xi64>
+//       CHECK: %[[ELE2:.+]]:2 = vector.to_elements %[[SRC2]] : vector<2xi64>
+//       CHECK: %[[RES:.+]] = vector.from_elements %[[ELE1]]#0, %[[ELE1]]#1, %[[ELE2]]#0, %[[ELE2]]#1 : vector<2x2xi64>
+//  CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_vector_to_vector(%dest : vector<2x2xi64>, %src1 : vector<2xi64>, %src2 : vector<2xi64>) -> vector<2x2xi64> {
+  %v1 = vector.insert %src1, %dest[0] : vector<2xi64> into vector<2x2xi64>
+  %v2 = vector.insert %src2, %v1[1] : vector<2xi64> into vector<2x2xi64>
+  return %v2 : vector<2x2xi64>
+}
+
+// -----
+
+// Test InsertChainFullyInitialized pattern with overlapping insertions.
+// 1. The first op inserts %src2 at [0,0].
+// 2. The second op inserts %src1 at [0,0], [0,1], overwriting %src2 at [0,0].
+// 3. The third op inserts %src1 at [1,0], [1,1].
+// 4. The fourth op inserts %src2 at [1,1], overwriting the existing value at [1,1].
+// CHECK-LABEL: func.func @fully_insert_to_vector_overlap_1(
+//  CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: vector<2xi64>, %[[SRC2:.+]]: i64)
+//       CHECK: %[[ELE1:.+]]:2 = vector.to_elements %[[SRC1]] : vector<2xi64>
+//       CHECK: %[[ELE2:.+]]:2 = vector.to_elements %[[SRC1]] : vector<2xi64>
+//       CHECK: %[[RES:.+]] = vector.from_elements %[[ELE1]]#0, %[[ELE1]]#1, %[[ELE2]]#0, %[[SRC2]] : vector<2x2xi64>
+//  CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_to_vector_overlap_1(%dest : vector<2x2xi64>, %src1 : vector<2xi64>, %src2 : i64) -> vector<2x2xi64> {
+  %v0 = vector.insert %src2, %dest[0, 0] : i64 into vector<2x2xi64>
+  %v1 = vector.insert %src1, %v0[0] : vector<2xi64> into vector<2x2xi64>
+  %v2 = vector.insert %src1, %v1[1] : vector<2xi64> into vector<2x2xi64>
+  %v3 = vector.insert %src2, %v2[1, 1] : i64 into vector<2x2xi64>
+  return %v3 : vector<2x2xi64>
+}
+
+// -----
+
+// Test InsertChainFullyInitialized pattern with overlapping insertions.
+// A vector inserted later should overwrite the previously inserted scalars.
+// CHECK-LABEL: func.func @fully_insert_to_vector_overlap_2(
+//  CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: i64, %[[SRC2:.+]]: i64, %[[SRC3:.+]]: vector<2xi64>, %[[SRC4:.+]]: vector<2xi64>)
+//       CHECK: %[[ELE1:.+]]:2 = vector.to_elements %[[SRC3]] : vector<2xi64>
+//       CHECK: %[[ELE2:.+]]:2 = vector.to_elements %[[SRC4]] : vector<2xi64>
+//       CHECK: %[[RES:.+]] = vector.from_elements %[[ELE1]]#0, %[[ELE1]]#1, %[[ELE2]]#0, %[[ELE2]]#1 : vector<2x2xi64>
+//  CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_to_vector_overlap_2(%dest : vector<2x2xi64>, %src1 : i64, %src2 : i64, %src3 : vector<2xi64>, %src4 : vector<2xi64>) -> vector<2x2xi64> {
+  %v0 = vector.insert %src1, %dest[0, 0] : i64 into vector<2x2xi64>
+  %v1 = vector.insert %src2, %v0[0, 1] : i64 into vector<2x2xi64>
+  %v2 = vector.insert %src3, %v1[0] : vector<2xi64> into vector<2x2xi64>
+  %v3 = vector.insert %src4, %v2[1] : vector<2xi64> into vector<2x2xi64>
+  return %v3 : vector<2x2xi64>
+}
+
+// -----
+
+// Negative test for InsertChainFullyInitialized pattern when only some elements are overwritten.
+// CHECK-LABEL: func.func @negative_partially_insert_vector_to_vector(
+//  CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: vector<2xi64>, %[[SRC2:.+]]: i64)
+//       CHECK: %[[V0:.+]] = vector.insert %[[SRC1]], %[[DEST]] [0] : vector<2xi64> into vector<2x2xi64>
+//       CHECK: %[[V1:.+]] = vector.insert %[[SRC2]], %[[V0]] [0, 0] : i64 into vector<2x2xi64>
+//       CHECK: return %[[V1]]
+func.func @negative_partially_insert_vector_to_vector(%dest : vector<2x2xi64>, %src1 : vector<2xi64>, %src2 : i64) -> vector<2x2xi64> {
+  %v1 = vector.insert %src1, %dest[0] : vector<2xi64> into vector<2x2xi64>
+  %v2 = vector.insert %src2, %v1[0, 0] : i64 into vector<2x2xi64>
+  return %v2 : vector<2x2xi64>
+}
+
+// -----
+
+// Negative test for InsertChainFullyInitialized pattern when an intermediate insertion has more than one user.
+// CHECK-LABEL: func.func @negative_intermediate_insert_multiple_users(
+//  CHECK-SAME: %[[DEST:.+]]: vector<3xi64>, %[[SRC1:.+]]: i64, %[[SRC2:.+]]: i64, %[[SRC3:.+]]: i64, %[[SRC4:.+]]: i64)
+//       CHECK: %[[V0:.+]] = vector.insert %[[SRC1]], %[[DEST]] [0] : i64 into vector<3xi64>
+//       CHECK: %[[V1:.+]] = vector.insert %[[SRC2]], %[[V0]] [1] : i64 into vector<3xi64>
+//       CHECK: %[[V2:.+]] = vector.insert %[[SRC3]], %[[V1]] [2] : i64 into vector<3xi64>
+//       CHECK: %[[V3:.+]] = vector.insert %[[SRC4]], %[[V1]] [2] : i64 into vector<3xi64>
+func.func @negative_intermediate_insert_multiple_users(%dest : vector<3xi64>, %src1 : i64, %src2 : i64, %src3 : i64, %src4 : i64) -> (vector<3xi64>, vector<3xi64>) {
+  %v1 = vector.insert %src1, %dest[0] : i64 into vector<3xi64>
+  %v2 = vector.insert %src2, %v1[1] : i64 into vector<3xi64>
+  %v3_0 = vector.insert %src3, %v2[2] : i64 into vector<3xi64>
+  %v3_1 = vector.insert %src4, %v2[2] : i64 into vector<3xi64>
+  return %v3_0, %v3_1 : vector<3xi64>, vector<3xi64>
+}
+
+// +---------------------------------------------------------------------------
+// End of tests for InsertChainFullyInitialized.
+// +---------------------------------------------------------------------------
+
+// -----
+
 // CHECK-LABEL: func.func @insert_2d_splat_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>

diff  --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 9c2a508671e06..0e1bad62ce763 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -198,7 +198,7 @@ func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<
 // CANON-NOT:     scf.if
 // CANON:         tensor.extract
 // CANON:         tensor.extract
-// CANON:         [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : f32 into vector<2xf32>
+// CANON:         [[FINAL:%.+]] = vector.from_elements %{{.+}}, %{{.+}} : vector<2xf32>
 // CANON-NEXT:    return [[FINAL]] : vector<2xf32>
 func.func @gather_tensor_1d_all_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
   %mask = arith.constant dense <true> : vector<2xi1>

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index c3ce7e9ca7fda..4d2c964a6df3c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1299,7 +1299,7 @@ func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
 //       CHECK-PROP:     %[[VAL:.*]] = "another_def"
 //       CHECK-PROP:     gpu.yield %[[VEC]], %[[VAL]]
-//       CHECK-PROP:   vector.insert %[[W]]#1, %[[W]]#0 [] : f32 into vector<f32>
+//       CHECK-PROP:   vector.broadcast %[[W]]#1 : f32 to vector<f32>
 func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
     %0 = "some_def"() : () -> (vector<f32>)


        


More information about the Mlir-commits mailing list