[Mlir-commits] [mlir] [MLIR][Vector] Remove implicit bitcast behavior from vector.extract (PR #186383)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 13 05:36:13 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

Drop the `isCompatibleReturnTypes` override on `ExtractOp`, which let `vector.extract` declare a `vector<1xT>` result when the inferred return type is the scalar `T` (and vice versa). With the override gone, only an exact match between declared and inferred result types is accepted, so switch the op from `InferTypeOpAdaptorWithIsCompatible` to `InferTypeOpAdaptor`.

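Update the tests in ConvertToSPIRV, VectorToLLVM, and VectorToSPIRV that relied on this implicit scalar↔single-element-vector result-type conversion. Remove the `vector<1xf32>` extract cases, and rewrite the scalable VectorToLLVM case to extract a scalar `f32`.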

---
Full diff: https://github.com/llvm/llvm-project/pull/186383.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-13) 
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+2-4) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+4-18) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+2-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..f8ece8b734c42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -664,7 +664,7 @@ def Vector_ExtractOp :
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
-     InferTypeOpAdaptorWithIsCompatible]> {
+     InferTypeOpAdaptor]> {
   let summary = "extract operation";
   let description = [{
     Extracts an (n − k)-D result sub-vector from an n-D source vector at a
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 927d35342cfdd..73632875ca9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1354,19 +1354,6 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
   return success();
 }
 
-bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
-  // Allow extracting 1-element vectors instead of scalars.
-  auto isCompatible = [](TypeRange l, TypeRange r) {
-    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
-    return vectorType && vectorType.getShape().equals({1}) &&
-           vectorType.getElementType() == r.front();
-  };
-  if (l.size() == 1 && r.size() == 1 &&
-      (isCompatible(l, r) || isCompatible(r, l)))
-    return true;
-  return l == r;
-}
-
 LogicalResult vector::ExtractOp::verify() {
   if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
     if (resTy.getRank() == 0)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index cd8cfc8736915..dfe0f9318a942 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -2,12 +2,10 @@
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-//       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
   %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
-  return %0, %1: vector<1xf32>, f32
+  return %1: f32
 }
 
 // -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 49c55f5b54496..782bfe029faa3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -300,29 +300,15 @@ func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f
 
 // -----
 
-func.func @extract_vec_1e_from_vec_1d_f32(%arg0: vector<16xf32>) -> vector<1xf32> {
-  %0 = vector.extract %arg0[15]: vector<1xf32> from vector<16xf32>
-  return %0 : vector<1xf32>
-}
-// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32(
-//  CHECK-SAME:   %[[A:.*]]: vector<16xf32>)
-//       CHECK:   %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
-//       CHECK:   %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<16xf32>
-//       CHECK:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-//       CHECK:   return %[[T2]] : vector<1xf32>
-
-// -----
-
-func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> vector<1xf32> {
-  %0 = vector.extract %arg0[15]: vector<1xf32> from vector<[16]xf32>
-  return %0 : vector<1xf32>
+func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+  return %0 : f32
 }
 // CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32_scalable(
 //  CHECK-SAME:   %[[A:.*]]: vector<[16]xf32>)
 //       CHECK:   %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
 //       CHECK:   %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<[16]xf32>
-//       CHECK:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
-//       CHECK:   return %[[T2]] : vector<1xf32>
+//       CHECK:   return %[[T1]] : f32
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c3688e0657d4b..c399250151261 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -165,12 +165,10 @@ func.func @broadcast_index(%a: index) -> vector<4xindex> {
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-//       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (f32) {
   %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
-  return %0, %1: vector<1xf32>, f32
+  return %1: f32
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/186383


More information about the Mlir-commits mailing list