[Mlir-commits] [mlir] [mlir][Vector] Fix crash in drop unit dims (PR #87104)

Diego Caballero llvmlistbot at llvm.org
Fri Mar 29 11:59:12 PDT 2024


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/87104

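An `arith.select` may have a scalar condition and vector true/false values. `DropUnitDimFromElementwiseOps` assumed operand 0 was always a vector and crashed on `cast<VectorType>`; it now bails out when operand 0 is not a vector.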

>From 2b0d4aa5787b1d933dd64f5704b901bc7f4701b8 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 29 Mar 2024 18:44:29 +0000
Subject: [PATCH] [mlir][Vector] Fix crash in drop unit dims

An `arith.select` may have a scalar condition and true/false vector
values.
---
 .../Vector/Transforms/VectorTransforms.cpp    | 10 +++--
 .../vector-dropleadunitdim-transforms.mlir    | 42 ++++++++++++++++++-
 2 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6f6b6dcdad2006..69c497264fd1ed 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1643,10 +1643,12 @@ struct DropUnitDimFromElementwiseOps final
     if (!resultVectorType)
       return failure();
 
-    // Check the pre-conditions. For `Elementwise` Ops all operands are
-    // guaranteed to have identical shapes and it suffices to only check the
-    // first one.
-    auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
+    // Check the operand pre-conditions. For `Elementwise` ops all operands are
+    // guaranteed to have identical shapes (with some exceptions such as
+    // `arith.select`) and it suffices to only check one of them.
+    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
+    if (!sourceVectorType)
+      return failure();
     if (sourceVectorType.getRank() < 2)
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 4ba51c5953d13c..3a120a56056cad 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -40,7 +40,7 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
 // CHECK:           %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
 // CHECK:           %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
 // CHECK:           %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
-// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} 
+// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 // CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
 // CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
 // CHECK:           %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
@@ -76,7 +76,7 @@ func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<
 // CHECK:           %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
 // CHECK:           %[[CONTRACT:.*]] = vector.mask %[[M]] {
 // CHECK-SAME:      vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
-// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> 
+// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
 // CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
 // CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<1x16x16xf32>
@@ -472,6 +472,8 @@ func.func @cast_away_elementwise_leading_one_dims(
   return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
 //  CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -483,6 +485,8 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3
   return %0: vector<1x1x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_scalable(
 // CHECK-SAME:    %[[S:.*]]: f32,
 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
@@ -495,6 +499,8 @@ func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector
   return %0: vector<1x1x[4]xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
 // CHECK-SAME:    %[[S:.*]]: f32,
 // CHECK-SAME:    %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
@@ -507,6 +513,8 @@ func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %
   return %0: vector<1x[1]x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
 //  CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
 //       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
@@ -516,6 +524,8 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector
   return %0: vector<1x1x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank1_scalable(
 // CHECK-SAME:    %[[S:.*]]: vector<[4]xf32>,
 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
@@ -526,6 +536,8 @@ func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>,
   return %0: vector<1x1x[4]xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
@@ -536,6 +548,8 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
   return %0: vector<1x1x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_scalable(
 // CHECK-SAME:    %[[S:.*]]: vector<1x[4]xf32>,
 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
@@ -547,6 +561,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32
   return %0: vector<1x1x[4]xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
 //       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
@@ -559,6 +575,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>,
   return %0: vector<1x2x1x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
 // CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
 // CHECK-SAME:      %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
@@ -572,6 +590,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<
   return %0: vector<1x2x1x[4]xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
@@ -582,6 +602,8 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
   return %0: vector<8x1x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
 // CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
 // CHECK-SAME:      %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
@@ -593,6 +615,8 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x
   return %0: vector<8x1x[4]xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
 //       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1>
@@ -605,6 +629,8 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v
   return %0: vector<1x1x8x1x8xi1>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
 // CHECK-SAME:      %[[S:.*]]: vector<1x[8]xi1>,
 // CHECK-SAME:      %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
@@ -618,6 +644,8 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
   return %0: vector<1x1x8x1x[8]xi1>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
 // CHECK:           %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
@@ -626,3 +654,13 @@ func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
   %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
   return %0: vector<1x1x8x2x1xi1>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @drop_unit_dims_scalar_cond_select(
+// CHECK:           arith.select {{.*}} : vector<16xi1>
+func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> {
+  %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
+  return %sel : vector<1x16xi1>
+}
+



More information about the Mlir-commits mailing list