[Mlir-commits] [mlir] [mlir][Vector] Generate poison vectors in vector.shape_cast lowering (PR #125613)

Diego Caballero llvmlistbot at llvm.org
Mon Feb 3 17:13:05 PST 2025


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/125613

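This is the first PR that introduces `ub.poison` vectors as part of a rewrite/conversion pattern in the Vector dialect. In the `vector.shape_cast` lowering, the result vector that is filled in by a chain of `vector.insert`, `vector.insert_strided_slice` or `vector.scalable.insert` ops is now initialized with a `ub.poison` vector instead of an `arith.constant dense<0>` vector.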

This PR depends on all the previous PRs that introduced support for poison in Vector operations such as `vector.shuffle`, `vector.extract`, `vector.insert`, including ODS, canonicalization and lowering support.

This PR may improve end-to-end compilation time through LLVM, depending on the workloads.

>From d70f6e3346f703392dc74043070f95cc5a888e95 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Mon, 3 Feb 2025 16:58:54 -0800
Subject: [PATCH] [mlir][Vector] Generate poison vectors in vector.shape_cast
 lowering

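This is the first PR that introduces `ub.poison` vectors as part of a
rewrite/conversion pattern in the Vector dialect. In the
`vector.shape_cast` lowering, the result vector that is filled in by a
chain of `vector.insert`, `vector.insert_strided_slice` or
`vector.scalable.insert` ops is now initialized with a `ub.poison`
vector instead of an `arith.constant dense<0>` vector.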

This PR depends on all the previous PRs that introduced support for
poison in Vector operations such as `vector.shuffle`, `vector.extract`,
`vector.insert`, including ODS, canonicalization and lowering support.

This PR may improve end-to-end compilation time through LLVM, depending
on the workloads.
---
 .../Vector/Transforms/LowerVectorShapeCast.cpp | 15 +++++----------
 .../ConvertToSPIRV/vector-unroll.mlir          | 11 +++++++----
 ...ntract-to-matrix-intrinsics-transforms.mlir |  6 +++---
 ...r-shape-cast-lowering-scalable-vectors.mlir | 18 +++++++++---------
 .../vector-shape-cast-lowering-transforms.mlir | 14 +++++++-------
 5 files changed, 31 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 239dc9aa1de6fb..9c1e5fcee91de4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB//IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -73,8 +73,7 @@ class ShapeCastOpNDDownCastRewritePattern
     SmallVector<int64_t> srcIdx(srcRank - 1, 0);
     SmallVector<int64_t> resIdx(resRank, 0);
     int64_t extractSize = sourceVectorType.getShape().back();
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
 
     // Compute the indices of each 1-D vector element of the source extraction
     // and destination slice insertion and generate such instructions.
@@ -129,8 +128,7 @@ class ShapeCastOpNDUpCastRewritePattern
     SmallVector<int64_t> srcIdx(srcRank, 0);
     SmallVector<int64_t> resIdx(resRank - 1, 0);
     int64_t extractSize = resultVectorType.getShape().back();
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
         incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
@@ -184,8 +182,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     // within the source and result shape.
     SmallVector<int64_t> srcIdx(srcRank, 0);
     SmallVector<int64_t> resIdx(resRank, 0);
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     for (int64_t i = 0; i < numElts; i++) {
       if (i != 0) {
         incIdx(srcIdx, sourceVectorType);
@@ -291,9 +288,7 @@ class ScalableShapeCastOpRewritePattern
     auto extractionVectorType = VectorType::get(
         {minExtractionSize}, sourceVectorType.getElementType(), {true});
 
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     SmallVector<int64_t> srcIdx(srcRank, 0);
     SmallVector<int64_t> resIdx(resRank, 0);
 
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index 043f9422d8790f..f1cc1354d1e3bc 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -83,17 +83,20 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
 // CHECK-LABEL: @transpose
 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
 func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2xi32>
+  // CHECK: %[[CST:.*]] = ub.poison : vector<1x2xi32>
   // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[CST1:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+  // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST1]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
   // CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[CST2:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST2]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
   // CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[CST3:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+  // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST3]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
   // CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
   // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 4867a416e5d144..fd6895c01d78bd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -14,9 +14,9 @@
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//  CHECK-DAG:  %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
-//  CHECK-DAG:  %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32>
-//  CHECK-DAG:  %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+//  CHECK-DAG:  %[[vcst:.*]] = ub.poison : vector<8xf32>
+//  CHECK-DAG:  %[[vcst_0:.*]] = ub.poison : vector<12xf32>
+//  CHECK-DAG:  %[[vcst_1:.*]] = ub.poison : vector<2x3xf32>
 //      CHECK:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<4xf32> from vector<2x4xf32>
 //      CHECK:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
 //      CHECK:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index fde6ce91024464..b4518e57c39ddd 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -7,7 +7,7 @@
 // CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
 func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[8]xi32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
   // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[4]xi32> into vector<[8]xi32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
@@ -22,7 +22,7 @@ func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<
 // CHECK-LABEL: i32_1d_to_3d_last_dim_scalable
 // CHECK-SAME: %[[arg0:.*]]: vector<[8]xi32>
 func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x[4]xi32> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<2x1x[4]xi32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[4]xi32> from vector<[8]xi32>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][4] : vector<[4]xi32> from vector<[8]xi32>
@@ -37,7 +37,7 @@ func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x
 // CHECK-LABEL: i8_2d_to_1d_last_dim_scalable
 // CHECK-SAME: %[[arg0:.*]]: vector<4x[8]xi8>
 func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]xi8> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[32]xi8>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[32]xi8>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[8]xi8> from vector<4x[8]xi8>
   // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[8]xi8> into vector<[32]xi8>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[8]xi8> from vector<4x[8]xi8>
@@ -56,7 +56,7 @@ func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]
 // CHECK-LABEL: i8_1d_to_2d_last_dim_scalable
 // CHECK-SAME: %[[arg0:.*]]: vector<[32]xi8>
 func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]xi8> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<4x[8]xi8>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[8]xi8> from vector<[32]xi8>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[8]xi8> into vector<4x[8]xi8>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][8] : vector<[8]xi8> from vector<[32]xi8>
@@ -75,7 +75,7 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
 // CHECK-LABEL: f32_permute_leading_non_scalable_dims
 // CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
 func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<3x2x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
@@ -99,7 +99,7 @@ func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) ->
 // CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf64>
 func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> vector<4x[2]xf64>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4x[2]xf64>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<4x[2]xf64>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf64> from vector<2x2x[2]xf64>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf64> into vector<4x[2]xf64>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
@@ -109,7 +109,7 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
   // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
   // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf64> into vector<4x[2]xf64>
   %res = vector.shape_cast %arg0: vector<2x2x[2]xf64> to vector<4x[2]xf64>
-  // CHECK-NEXT: return %7 : vector<4x[2]xf64>
+  // CHECK-NEXT: return %[[res3]] : vector<4x[2]xf64>
   return %res : vector<4x[2]xf64>
 }
 
@@ -119,7 +119,7 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
 // CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
 func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<6x[2]xf32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<6x[2]xf32>
   // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
@@ -146,7 +146,7 @@ func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<
 // CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
 func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<2x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32>
   // CHECK-NEXT: %[[resvec0:.*]] = vector.extract %[[cst]][0] : vector<[4]xf32> from vector<2x[4]xf32>
   // CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[resvec0]][0] : vector<[2]xf32> into vector<[4]xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index b4c52d5533116c..ee4fe59424a482 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -22,8 +22,8 @@ func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
 // llvm.matrix operations
 // CHECK-LABEL: func @shape_casts
 func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
-  // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
-  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[cst22:.*]] = ub.poison : vector<2x2xf32>
+  // CHECK-DAG: %[[cst:.*]] = ub.poison : vector<4xf32>
   // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
   //
   // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
@@ -59,7 +59,7 @@ func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>)
 
 // CHECK-LABEL: func @shape_cast_2d2d
 // CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<2x3xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
 // CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
 // CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
@@ -81,7 +81,7 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
 
 // CHECK-LABEL: func @shape_cast_3d1d
 // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<6xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
 // CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
 // CHECK-SAME:           {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
@@ -100,7 +100,7 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
 
 // CHECK-LABEL: func @shape_cast_1d3d
 // CHECK-SAME: %[[A:.*]]: vector<6xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<2x1x3xf32>
 // CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
 // CHECK-SAME:           {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
 // CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
@@ -116,7 +116,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 
 // CHECK-LABEL:   func.func @shape_cast_0d1d(
 // CHECK-SAME:                               %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:           %[[VAL_1:.*]] = ub.poison : vector<1xf32>
 // CHECK:           %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
 // CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
 // CHECK:           return %[[VAL_3]] : vector<1xf32>
@@ -129,7 +129,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // CHECK-LABEL:   func.func @shape_cast_1d0d(
 // CHECK-SAME:                               %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK:           %[[VAL_1:.*]] = ub.poison : vector<f32>
 // CHECK:           %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<1xf32>
 // CHECK:           %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
 // CHECK:           return %[[VAL_3]] : vector<f32>



More information about the Mlir-commits mailing list