[Mlir-commits] [mlir] [mlir][Vector] Fix crash in drop unit dims (PR #87104)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 29 11:59:41 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
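An `arith.select` may have a scalar condition with vector true/false values, so the first operand of an elementwise op is not necessarily a vector. `DropUnitDimFromElementwiseOps` previously did an unchecked `cast<VectorType>` on that operand and crashed; it now uses `dyn_cast` and bails out when the operand is not a vector.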
---
Full diff: https://github.com/llvm/llvm-project/pull/87104.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+6-4)
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+40-2)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6f6b6dcdad2006..69c497264fd1ed 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1643,10 +1643,12 @@ struct DropUnitDimFromElementwiseOps final
if (!resultVectorType)
return failure();
- // Check the pre-conditions. For `Elementwise` Ops all operands are
- // guaranteed to have identical shapes and it suffices to only check the
- // first one.
- auto sourceVectorType = cast<VectorType>(op->getOperands()[0].getType());
+ // Check the operand pre-conditions. For `Elementwise` ops all operands are
+ // guaranteed to have identical shapes (with some exceptions such as
+ // `arith.select`) and it suffices to only check one of them.
+ auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
+ if (!sourceVectorType)
+ return failure();
if (sourceVectorType.getRank() < 2)
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 4ba51c5953d13c..3a120a56056cad 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -40,7 +40,7 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
-// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
@@ -76,7 +76,7 @@ func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<
// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
-// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>
@@ -472,6 +472,8 @@ func.func @cast_away_elementwise_leading_one_dims(
return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}
+// -----
+
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
// CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -483,6 +485,8 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3
return %0: vector<1x1x4xf32>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_scalable(
// CHECK-SAME: %[[S:.*]]: f32,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
@@ -495,6 +499,8 @@ func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector
return %0: vector<1x1x[4]xf32>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
// CHECK-SAME: %[[S:.*]]: f32,
// CHECK-SAME: %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
@@ -507,6 +513,8 @@ func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %
return %0: vector<1x[1]x4xf32>
}
+// -----
+
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
// CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
@@ -516,6 +524,8 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector
return %0: vector<1x1x4xf32>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank1_scalable(
// CHECK-SAME: %[[S:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
@@ -526,6 +536,8 @@ func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>,
return %0: vector<1x1x[4]xf32>
}
+// -----
+
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
@@ -536,6 +548,8 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
return %0: vector<1x1x4xf32>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
@@ -547,6 +561,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32
return %0: vector<1x1x[4]xf32>
}
+// -----
+
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
@@ -559,6 +575,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>,
return %0: vector<1x2x1x4xf32>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
@@ -572,6 +590,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<
return %0: vector<1x2x1x[4]xf32>
}
+// -----
+
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
@@ -582,6 +602,8 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
return %0: vector<8x1x4xf32>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
@@ -593,6 +615,8 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x
return %0: vector<8x1x[4]xf32>
}
+// -----
+
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1>
@@ -605,6 +629,8 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v
return %0: vector<1x1x8x1x8xi1>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[8]xi1>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
@@ -618,6 +644,8 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
return %0: vector<1x1x8x1x[8]xi1>
}
+// -----
+
// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
@@ -626,3 +654,13 @@ func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
%0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
return %0: vector<1x1x8x2x1xi1>
}
+
+// -----
+
+// CHECK-LABEL: func.func @drop_unit_dims_scalar_cond_select(
+// CHECK: arith.select {{.*}} : vector<16xi1>
+func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> {
+ %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
+ return %sel : vector<1x16xi1>
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/87104
More information about the Mlir-commits
mailing list