[Mlir-commits] [mlir] 3330ca9 - [MLIR][Vector] Remove implicit bitcast behavior from vector.extract (#186383)
llvmlistbot at llvm.org
Mon Mar 23 04:50:55 PDT 2026
Author: Mehdi Amini
Date: 2026-03-23T12:50:50+01:00
New Revision: 3330ca954e466a09ce20e7aa1dcb71909eba6081
URL: https://github.com/llvm/llvm-project/commit/3330ca954e466a09ce20e7aa1dcb71909eba6081
DIFF: https://github.com/llvm/llvm-project/commit/3330ca954e466a09ce20e7aa1dcb71909eba6081.diff
LOG: [MLIR][Vector] Remove implicit bitcast behavior from vector.extract (#186383)
Drop the `isCompatibleReturnTypes` override on `ExtractOp` that allowed
`vector.extract` to return a `vector<1xT>` when the natural inferred
return type is scalar `T` (and vice versa). Switch the op from
`InferTypeOpAdaptorWithIsCompatible` to `InferTypeOpAdaptor` to match.
RFC:
https://discourse.llvm.org/t/rfc-remove-implicit-bitcast-behavior-of-vector-extract/90178
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/ConvertToSPIRV/vector.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..f8ece8b734c42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -664,7 +664,7 @@ def Vector_ExtractOp :
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
- InferTypeOpAdaptorWithIsCompatible]> {
+ InferTypeOpAdaptor]> {
let summary = "extract operation";
let description = [{
Extracts an (n − k)-D result sub-vector from an n-D source vector at a
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 927d35342cfdd..73632875ca9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1354,19 +1354,6 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
return success();
}
-bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
- // Allow extracting 1-element vectors instead of scalars.
- auto isCompatible = [](TypeRange l, TypeRange r) {
- auto vectorType = llvm::dyn_cast<VectorType>(l.front());
- return vectorType && vectorType.getShape().equals({1}) &&
- vectorType.getElementType() == r.front();
- };
- if (l.size() == 1 && r.size() == 1 &&
- (isCompatible(l, r) || isCompatible(r, l)))
- return true;
- return l == r;
-}
-
LogicalResult vector::ExtractOp::verify() {
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
if (resTy.getRank() == 0)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index cd8cfc8736915..dfe0f9318a942 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -2,12 +2,10 @@
// CHECK-LABEL: @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
- %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
%1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
- return %0, %1: vector<1xf32>, f32
+ return %1: f32
}
// -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 076209cbc7a4c..d570d46e11b4a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -300,29 +300,15 @@ func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f
// -----
-func.func @extract_vec_1e_from_vec_1d_f32(%arg0: vector<16xf32>) -> vector<1xf32> {
- %0 = vector.extract %arg0[15]: vector<1xf32> from vector<16xf32>
- return %0 : vector<1xf32>
-}
-// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32(
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
-// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<16xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-// CHECK: return %[[T2]] : vector<1xf32>
-
-// -----
-
-func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> vector<1xf32> {
- %0 = vector.extract %arg0[15]: vector<1xf32> from vector<[16]xf32>
- return %0 : vector<1xf32>
+func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
+ %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+ return %0 : f32
}
// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32_scalable(
// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
// CHECK: %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<[16]xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-// CHECK: return %[[T2]] : vector<1xf32>
+// CHECK: return %[[T1]] : f32
// -----
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c3688e0657d4b..c399250151261 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -165,12 +165,10 @@ func.func @broadcast_index(%a: index) -> vector<4xindex> {
// CHECK-LABEL: @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
- %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
%1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
- return %0, %1: vector<1xf32>, f32
+ return %1: f32
}
// -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d8e08c8b2a850..8f8429e5844d6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -169,6 +169,26 @@ func.func @extract_0d_result(%arg0: vector<f32>) {
// -----
+// Extracting a scalar position from a 1-D vector must return a scalar, not a
+// single-element vector (implicit bitcast is not allowed).
+func.func @extract_scalar_as_single_element_vector(%arg0: vector<2xf32>) {
+ // expected-error at +2 {{'vector.extract' op failed to infer returned types}}
+ // expected-error at +1 {{'vector.extract' op inferred type(s) 'f32' are incompatible with return type(s) of operation 'vector<1xf32>'}}
+ %0 = vector.extract %arg0[0] : vector<1xf32> from vector<2xf32>
+}
+
+// -----
+
+// Extracting a single-element sub-vector from an n-D vector must return the
+// inferred vector type, not a scalar (implicit bitcast is not allowed).
+func.func @extract_subvec_as_scalar(%arg0: vector<3x1xf32>) {
+ // expected-error at +2 {{'vector.extract' op failed to infer returned types}}
+ // expected-error at +1 {{'vector.extract' op inferred type(s) 'vector<1xf32>' are incompatible with return type(s) of operation 'f32'}}
+ %0 = vector.extract %arg0[0] : f32 from vector<3x1xf32>
+}
+
+// -----
+
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error at +1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}}
%1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
More information about the Mlir-commits mailing list