[Mlir-commits] [mlir] [mlir][Vector] Fix vector.insert folder for scalar to 0-d inserts (PR #113828)

Kunwar Grover llvmlistbot at llvm.org
Mon Oct 28 08:06:47 PDT 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/113828

>From c56479dbb9e746057c58fb640e6504152c8990bc Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 27 Oct 2024 18:14:07 +0000
Subject: [PATCH 1/3] [mlir][Vector] Fix vector.insert folder for scalar to 0-d
 inserts

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  8 ++++----
 mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..03d2409f42c524 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2951,11 +2951,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
               InsertOpConstantFolder>(context);
 }
 
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
-  if (getNumIndices() == 0)
+  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
+  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+  // (type mismatch).
+  if (getNumIndices() == 0 && getSourceType() == getResult().getType())
     return getSource();
   return {};
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6d6bc199e601c0..580daa2a13d15e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2745,6 +2745,18 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
 
 // -----
 
+// CHECK-LABEL: func @insert_into_0d_regression(
+//  CHECK-SAME:     %[[v:.*]]: vector<f32>)
+//       CHECK:   %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
+//       CHECK:   return %[[extract]]
+func.func @insert_into_0d_regression(%v: vector<f32>) -> vector<f32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.insert %cst, %v [] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
 // CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
 // CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>

>From 37654707c6036978d45850786b437a7b25dfbb8b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 28 Oct 2024 11:50:02 +0000
Subject: [PATCH 2/3] Address comments

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 2 +-
 mlir/test/Dialect/Vector/canonicalize.mlir | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 03d2409f42c524..1853ae04f45d90 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2955,7 +2955,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
   // (type mismatch).
-  if (getNumIndices() == 0 && getSourceType() == getResult().getType())
+  if (getNumIndices() == 0 && getSourceType() == getType())
     return getSource();
   return {};
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 580daa2a13d15e..484764f7671c91 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2745,11 +2745,11 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
 
 // -----
 
-// CHECK-LABEL: func @insert_into_0d_regression(
+// CHECK-LABEL: func @insert_no_fold_type_mismatch(
 //  CHECK-SAME:     %[[v:.*]]: vector<f32>)
 //       CHECK:   %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
 //       CHECK:   return %[[extract]]
-func.func @insert_into_0d_regression(%v: vector<f32>) -> vector<f32> {
+func.func @insert_no_fold_type_mismatch(%v: vector<f32>) -> vector<f32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = vector.insert %cst, %v [] : f32 into vector<f32>
   return %0 : vector<f32>

>From 95dec126188363c5c1fad4a80e386920ae4e292b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 28 Oct 2024 15:06:21 +0000
Subject: [PATCH 3/3] Address comments

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 60 +++++++++++++---------
 1 file changed, 37 insertions(+), 23 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 484764f7671c91..c963460e7259fb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -800,6 +800,43 @@ func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vecto
 
 // -----
 
+// CHECK-LABEL: func @extract_no_fold_scalar_to_0d(
+//  CHECK-SAME:     %[[v:.*]]: vector<f32>)
+//       CHECK:   %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
+//       CHECK:   return %[[extract]]
+func.func @extract_no_fold_scalar_to_0d(%v: vector<f32>) -> f32 {
+  %0 = vector.extract %v[] : f32 from vector<f32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_fold_same_rank(
+//  CHECK-SAME:     %[[v:.*]]: vector<2x2xf32>)
+//       CHECK:      %[[CST:.+]] = arith.constant
+//  CHECK-SAME:                    : vector<2x2xf32>
+//       CHECK-NOT:  vector.insert
+//       CHECK:   return %[[CST]]
+func.func @insert_fold_same_rank(%v: vector<2x2xf32>) -> vector<2x2xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+  %0 = vector.insert %cst, %v [] : vector<2x2xf32> into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_no_fold_scalar_to_0d(
+//  CHECK-SAME:     %[[v:.*]]: vector<f32>)
+//       CHECK:   %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
+//       CHECK:   return %[[extract]]
+func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.insert %cst, %v [] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// -----
+
 // CHECK-LABEL: dont_fold_expand_collapse
 //       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
 //       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -2606,17 +2643,6 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
 
 // -----
 
-// CHECK-LABEL: func @extract_from_0d_regression(
-//  CHECK-SAME:     %[[v:.*]]: vector<f32>)
-//       CHECK:   %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
-//       CHECK:   return %[[extract]]
-func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
-  %0 = vector.extract %v[] : f32 from vector<f32>
-  return %0 : f32
-}
-
-// -----
-
 // CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
 //  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
 func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
@@ -2745,18 +2771,6 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
 
 // -----
 
-// CHECK-LABEL: func @insert_no_fold_type_mismatch(
-//  CHECK-SAME:     %[[v:.*]]: vector<f32>)
-//       CHECK:   %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
-//       CHECK:   return %[[extract]]
-func.func @insert_no_fold_type_mismatch(%v: vector<f32>) -> vector<f32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = vector.insert %cst, %v [] : f32 into vector<f32>
-  return %0 : vector<f32>
-}
-
-// -----
-
 // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
 // CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
 // CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>



More information about the Mlir-commits mailing list