[Mlir-commits] [mlir] 3330ca9 - [MLIR][Vector] Remove implicit bitcast behavior from vector.extract (#186383)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 23 04:50:55 PDT 2026


Author: Mehdi Amini
Date: 2026-03-23T12:50:50+01:00
New Revision: 3330ca954e466a09ce20e7aa1dcb71909eba6081

URL: https://github.com/llvm/llvm-project/commit/3330ca954e466a09ce20e7aa1dcb71909eba6081
DIFF: https://github.com/llvm/llvm-project/commit/3330ca954e466a09ce20e7aa1dcb71909eba6081.diff

LOG: [MLIR][Vector] Remove implicit bitcast behavior from vector.extract (#186383)

Drop the `isCompatibleReturnTypes` override on `ExtractOp` that allowed
`vector.extract` to return a `vector<1xT>` when the natural inferred
return type is scalar `T` (and vice versa). Switch the op from
`InferTypeOpAdaptorWithIsCompatible` to `InferTypeOpAdaptor` to match.

RFC:
https://discourse.llvm.org/t/rfc-remove-implicit-bitcast-behavior-of-vector-extract/90178

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/ConvertToSPIRV/vector.mlir
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..f8ece8b734c42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -664,7 +664,7 @@ def Vector_ExtractOp :
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
-     InferTypeOpAdaptorWithIsCompatible]> {
+     InferTypeOpAdaptor]> {
   let summary = "extract operation";
   let description = [{
     Extracts an (n − k)-D result sub-vector from an n-D source vector at a

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 927d35342cfdd..73632875ca9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1354,19 +1354,6 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
   return success();
 }
 
-bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
-  // Allow extracting 1-element vectors instead of scalars.
-  auto isCompatible = [](TypeRange l, TypeRange r) {
-    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
-    return vectorType && vectorType.getShape().equals({1}) &&
-           vectorType.getElementType() == r.front();
-  };
-  if (l.size() == 1 && r.size() == 1 &&
-      (isCompatible(l, r) || isCompatible(r, l)))
-    return true;
-  return l == r;
-}
-
 LogicalResult vector::ExtractOp::verify() {
   if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
     if (resTy.getRank() == 0)

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index cd8cfc8736915..dfe0f9318a942 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -2,12 +2,10 @@
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-//       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
   %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
-  return %0, %1: vector<1xf32>, f32
+  return %1: f32
 }
 
 // -----

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 076209cbc7a4c..d570d46e11b4a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -300,29 +300,15 @@ func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f
 
 // -----
 
-func.func @extract_vec_1e_from_vec_1d_f32(%arg0: vector<16xf32>) -> vector<1xf32> {
-  %0 = vector.extract %arg0[15]: vector<1xf32> from vector<16xf32>
-  return %0 : vector<1xf32>
-}
-// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32(
-//  CHECK-SAME:   %[[A:.*]]: vector<16xf32>)
-//       CHECK:   %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
-//       CHECK:   %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<16xf32>
-//       CHECK:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-//       CHECK:   return %[[T2]] : vector<1xf32>
-
-// -----
-
-func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> vector<1xf32> {
-  %0 = vector.extract %arg0[15]: vector<1xf32> from vector<[16]xf32>
-  return %0 : vector<1xf32>
+func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+  return %0 : f32
 }
 // CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32_scalable(
 //  CHECK-SAME:   %[[A:.*]]: vector<[16]xf32>)
 //       CHECK:   %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
 //       CHECK:   %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<[16]xf32>
-//       CHECK:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-//       CHECK:   return %[[T2]] : vector<1xf32>
+//       CHECK:   return %[[T1]] : f32
 
 // -----
 

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c3688e0657d4b..c399250151261 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -165,12 +165,10 @@ func.func @broadcast_index(%a: index) -> vector<4xindex> {
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-//       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
   %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
-  return %0, %1: vector<1xf32>, f32
+  return %1: f32
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d8e08c8b2a850..8f8429e5844d6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -169,6 +169,26 @@ func.func @extract_0d_result(%arg0: vector<f32>) {
 
 // -----
 
+// Extracting a scalar position from a 1-D vector must return a scalar, not a
+// single-element vector (implicit bitcast is not allowed).
+func.func @extract_scalar_as_single_element_vector(%arg0: vector<2xf32>) {
+  // expected-error@+2 {{'vector.extract' op failed to infer returned types}}
+  // expected-error@+1 {{'vector.extract' op inferred type(s) 'f32' are incompatible with return type(s) of operation 'vector<1xf32>'}}
+  %0 = vector.extract %arg0[0] : vector<1xf32> from vector<2xf32>
+}
+
+// -----
+
+// Extracting a single-element sub-vector from an n-D vector must return the
+// inferred vector type, not a scalar (implicit bitcast is not allowed).
+func.func @extract_subvec_as_scalar(%arg0: vector<3x1xf32>) {
+  // expected-error@+2 {{'vector.extract' op failed to infer returned types}}
+  // expected-error@+1 {{'vector.extract' op inferred type(s) 'vector<1xf32>' are incompatible with return type(s) of operation 'f32'}}
+  %0 = vector.extract %arg0[0] : f32 from vector<3x1xf32>
+}
+
+// -----
+
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
  // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}}
   %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>


        


More information about the Mlir-commits mailing list