[Mlir-commits] [mlir] [mlir][Vector] Support scalar `vector.extract` in VectorLinearize (PR #147440)

Diego Caballero llvmlistbot at llvm.org
Mon Jul 7 17:57:50 PDT 2025


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/147440

It generates a linearized (1-D) version of `vector.extract` when the extracted value is a scalar.

>From 781f0cac3296d5f0d813f0054eab005ed8df783d Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 8 Jul 2025 00:54:54 +0000
Subject: [PATCH] [mlir][Vector] Support scalar `vector.extract` in
 VectorLinearize

Generate a linearized (1-D) `vector.extract` when extracting a scalar from an n-D vector.
---
 .../Vector/Transforms/VectorLinearize.cpp     | 49 ++++++++++++++-----
 mlir/test/Dialect/Vector/linearize.mlir       | 26 +++++-----
 2 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7cac1cbafdd64..8b232aafbca9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
   }
 };
 
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-///   vector.extract %source [ position ]
-/// is converted to :
-///   %source_1d = vector.shape_cast %source
-///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-///   %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
+/// This pattern linearizes `vector.extract` operations. It generates a 1-D
+/// version of the `vector.extract` operation when extracting a scalar from a
+/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
+/// subvector from a larger vector.
+///
+/// Example #1:
+///
+///     %0 = vector.extract %arg0[1] : vector<8x2xf32> from vector<2x8x2xf32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
+///     %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+///                                 24, 25, 26, 27, 28, 29, 30, 31] :
+///            vector<32xf32>, vector<32xf32>
+///     %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
+///
+/// Example #2:
+///
+///     %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
+///     %1 = vector.extract %0[6] : i32 from vector<8xi32>
+///
 struct LinearizeVectorExtract final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Skip if result is not a vector type
-    if (!isa<VectorType>(extractOp.getType()))
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalar extract not supported");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(dstTy && "expected 1-D vector type");
 
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
       linearizedOffset += offsets[i] * size;
     }
 
+    if (!isa<VectorType>(extractOp.getType())) {
+      // Scalar case: generate a 1-D extract.
+      Value result = rewriter.createOrFold<vector::ExtractOp>(
+          extractOp.getLoc(), adaptor.getVector(), linearizedOffset);
+      rewriter.replaceOp(extractOp, result);
+      return success();
+    }
+
+    // Vector case: generate a shuffle.
+
     llvm::SmallVector<int64_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), linearizedOffset);
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 894171500d9d6..cbc15f34918f6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
 
 // -----
 
+// CHECK-LABEL: test_vector_extract_scalar
+// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
+func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
+
+  // CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
+  // CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
+  // CHECK: return %[[EXTRACT_1D]] : i32
+  %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+  return %0 : i32
+}
+
+// -----
+
 // CHECK-LABEL: test_vector_extract
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 
 // -----
 
-// CHECK-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar(%idx : index) {
-  %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
-
-  // CHECK-NOT: vector.shuffle
-  // CHECK:     vector.extract
-  // CHECK-NOT: vector.shuffle
-  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: test_vector_bitcast
 // CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
 func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {



More information about the Mlir-commits mailing list