[Mlir-commits] [mlir] [MLIR][Vector] Remove implicit bitcast behavior from vector.extract (PR #186383)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 13 05:36:13 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
Drop the `isCompatibleReturnTypes` override on `ExtractOp` that allowed `vector.extract` to return a `vector<1xT>` when the natural inferred return type is scalar `T` (and vice versa). Switch the op from `InferTypeOpAdaptorWithIsCompatible` to `InferTypeOpAdaptor` to match.
Update tests in ConvertToSPIRV, VectorToLLVM, and VectorToSPIRV to remove uses of this implicit scalar↔single-element-vector bitcast.
---
Full diff: https://github.com/llvm/llvm-project/pull/186383.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-13)
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+2-4)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+4-18)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+2-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..f8ece8b734c42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -664,7 +664,7 @@ def Vector_ExtractOp :
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
- InferTypeOpAdaptorWithIsCompatible]> {
+ InferTypeOpAdaptor]> {
let summary = "extract operation";
let description = [{
Extracts an (n − k)-D result sub-vector from an n-D source vector at a
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 927d35342cfdd..73632875ca9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1354,19 +1354,6 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
return success();
}
-bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
- // Allow extracting 1-element vectors instead of scalars.
- auto isCompatible = [](TypeRange l, TypeRange r) {
- auto vectorType = llvm::dyn_cast<VectorType>(l.front());
- return vectorType && vectorType.getShape().equals({1}) &&
- vectorType.getElementType() == r.front();
- };
- if (l.size() == 1 && r.size() == 1 &&
- (isCompatible(l, r) || isCompatible(r, l)))
- return true;
- return l == r;
-}
-
LogicalResult vector::ExtractOp::verify() {
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
if (resTy.getRank() == 0)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index cd8cfc8736915..dfe0f9318a942 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -2,12 +2,10 @@
// CHECK-LABEL: @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
- %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
%1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
- return %0, %1: vector<1xf32>, f32
+ return %1: f32
}
// -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 49c55f5b54496..782bfe029faa3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -300,29 +300,15 @@ func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f
// -----
-func.func @extract_vec_1e_from_vec_1d_f32(%arg0: vector<16xf32>) -> vector<1xf32> {
- %0 = vector.extract %arg0[15]: vector<1xf32> from vector<16xf32>
- return %0 : vector<1xf32>
-}
-// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32(
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
-// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<16xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-// CHECK: return %[[T2]] : vector<1xf32>
-
-// -----
-
-func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> vector<1xf32> {
- %0 = vector.extract %arg0[15]: vector<1xf32> from vector<[16]xf32>
- return %0 : vector<1xf32>
+func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
+ %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+ return %0 : f32
}
// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32_scalable(
// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
// CHECK: %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<[16]xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-// CHECK: return %[[T2]] : vector<1xf32>
+// CHECK: return %[[T1]] : f32
// -----
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c3688e0657d4b..c399250151261 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -165,12 +165,10 @@ func.func @broadcast_index(%a: index) -> vector<4xindex> {
// CHECK-LABEL: @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
- %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
%1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
- return %0, %1: vector<1xf32>, f32
+ return %1: f32
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/186383
More information about the Mlir-commits mailing list