[Mlir-commits] [mlir] [MLIR][Vector] Remove implicit bitcast behavior from vector.extract (PR #186383)

Mehdi Amini llvmlistbot at llvm.org
Mon Mar 16 08:13:33 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/186383

>From 88e0c8bc771a1732fbf3a4d05a3c68211dbc3a78 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 13 Mar 2026 05:33:07 -0700
Subject: [PATCH 1/2] [MLIR][Vector] Remove implicit bitcast behavior from
 vector.extract
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Drop the `isCompatibleReturnTypes` override on `ExtractOp` that allowed
`vector.extract` to return a `vector<1xT>` when the natural inferred
return type is scalar `T` (and vice versa). Switch the op from
`InferTypeOpAdaptorWithIsCompatible` to `InferTypeOpAdaptor` to match.

Update tests in ConvertToSPIRV, VectorToLLVM, and VectorToSPIRV to
remove uses of this implicit scalar↔single-element-vector bitcast.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 13 -----------
 .../Conversion/ConvertToSPIRV/vector.mlir     |  6 ++---
 .../vector-to-llvm-interface.mlir             | 22 ++++---------------
 .../VectorToSPIRV/vector-to-spirv.mlir        |  6 ++---
 5 files changed, 9 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..f8ece8b734c42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -664,7 +664,7 @@ def Vector_ExtractOp :
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
-     InferTypeOpAdaptorWithIsCompatible]> {
+     InferTypeOpAdaptor]> {
   let summary = "extract operation";
   let description = [{
     Extracts an (n − k)-D result sub-vector from an n-D source vector at a
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 927d35342cfdd..73632875ca9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1354,19 +1354,6 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
   return success();
 }
 
-bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
-  // Allow extracting 1-element vectors instead of scalars.
-  auto isCompatible = [](TypeRange l, TypeRange r) {
-    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
-    return vectorType && vectorType.getShape().equals({1}) &&
-           vectorType.getElementType() == r.front();
-  };
-  if (l.size() == 1 && r.size() == 1 &&
-      (isCompatible(l, r) || isCompatible(r, l)))
-    return true;
-  return l == r;
-}
-
 LogicalResult vector::ExtractOp::verify() {
   if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
     if (resTy.getRank() == 0)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index cd8cfc8736915..dfe0f9318a942 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -2,12 +2,10 @@
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-//       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
   %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
-  return %0, %1: vector<1xf32>, f32
+  return %1: f32
 }
 
 // -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 49c55f5b54496..782bfe029faa3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -300,29 +300,15 @@ func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f
 
 // -----
 
-func.func @extract_vec_1e_from_vec_1d_f32(%arg0: vector<16xf32>) -> vector<1xf32> {
-  %0 = vector.extract %arg0[15]: vector<1xf32> from vector<16xf32>
-  return %0 : vector<1xf32>
-}
-// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32(
-//  CHECK-SAME:   %[[A:.*]]: vector<16xf32>)
-//       CHECK:   %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
-//       CHECK:   %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<16xf32>
-//       CHECK:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-//       CHECK:   return %[[T2]] : vector<1xf32>
-
-// -----
-
-func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> vector<1xf32> {
-  %0 = vector.extract %arg0[15]: vector<1xf32> from vector<[16]xf32>
-  return %0 : vector<1xf32>
+func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+  return %0 : f32
 }
 // CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32_scalable(
 //  CHECK-SAME:   %[[A:.*]]: vector<[16]xf32>)
 //       CHECK:   %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
 //       CHECK:   %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<[16]xf32>
-//       CHECK:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-//       CHECK:   return %[[T2]] : vector<1xf32>
+//       CHECK:   return %[[T1]] : f32
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c3688e0657d4b..c399250151261 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -165,12 +165,10 @@ func.func @broadcast_index(%a: index) -> vector<4xindex> {
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-//       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
   %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
-  return %0, %1: vector<1xf32>, f32
+  return %1: f32
 }
 
 // -----

>From 6ec3da6603515bd5eafa3b13f27fdca6adfb416b Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 13 Mar 2026 06:20:41 -0700
Subject: [PATCH 2/2] [MLIR][Vector] Add invalid tests for vector.extract
 implicit bitcast

Add two test cases to invalid.mlir that verify vector.extract no longer
accepts a single-element vector (vector<1xT>) as the result when the
inferred type is a scalar T, nor a scalar T as the result when the
inferred type is a single-element sub-vector (vector<1xT>).

Assisted-by: Claude Code
---
 mlir/test/Dialect/Vector/invalid.mlir | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3957455ccc76e..9a930526531d9 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -169,6 +169,26 @@ func.func @extract_0d_result(%arg0: vector<f32>) {
 
 // -----
 
+// Extracting a scalar position from a 1-D vector must return a scalar, not a
+// single-element vector (implicit bitcast is not allowed).
+func.func @extract_scalar_as_single_element_vector(%arg0: vector<2xf32>) {
+  // expected-error@+2 {{'vector.extract' op failed to infer returned types}}
+  // expected-error@+1 {{'vector.extract' op inferred type(s) 'f32' are incompatible with return type(s) of operation 'vector<1xf32>'}}
+  %0 = vector.extract %arg0[0] : vector<1xf32> from vector<2xf32>
+}
+
+// -----
+
+// Extracting a single-element sub-vector from an n-D vector must return the
+// inferred vector type, not a scalar (implicit bitcast is not allowed).
+func.func @extract_subvec_as_scalar(%arg0: vector<3x1xf32>) {
+  // expected-error@+2 {{'vector.extract' op failed to infer returned types}}
+  // expected-error@+1 {{'vector.extract' op inferred type(s) 'vector<1xf32>' are incompatible with return type(s) of operation 'f32'}}
+  %0 = vector.extract %arg0[0] : f32 from vector<3x1xf32>
+}
+
+// -----
+
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}}
   %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>



More information about the Mlir-commits mailing list