[Mlir-commits] [mlir] [mlir][Vector] Support scalar `vector.extract` in VectorLinearize (PR #147440)
Diego Caballero
llvmlistbot at llvm.org
Mon Jul 7 17:57:50 PDT 2025
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/147440
It makes VectorLinearize linearize scalar `vector.extract` operations (extracting a single element from an n-D vector) into a 1-D `vector.extract`, instead of rejecting them.
>From 781f0cac3296d5f0d813f0054eab005ed8df783d Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 8 Jul 2025 00:54:54 +0000
Subject: [PATCH] [mlir][Vector] Support scalar `vector.extract` in
VectorLinearize
Generate a linearized (1-D) version of `vector.extract` when the extracted result is a scalar.
---
.../Vector/Transforms/VectorLinearize.cpp | 49 ++++++++++++++-----
mlir/test/Dialect/Vector/linearize.mlir | 26 +++++-----
2 files changed, 49 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7cac1cbafdd64..8b232aafbca9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
}
};
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-/// vector.extract %source [ position ]
-/// is converted to :
-/// %source_1d = vector.shape_cast %source
-/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
+/// This pattern linearizes `vector.extract` operations. It generates a 1-D
+/// version of the `vector.extract` operation when extracting a scalar from a
+/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
+/// subvector from a larger vector.
+///
+/// Example #1:
+///
+/// %0 = vector.extract %arg0[1] : vector<8x2xf32> from vector<2x8x2xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
+/// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+/// 24, 25, 26, 27, 28, 29, 30, 31] :
+/// vector<32xf32>, vector<32xf32>
+/// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
+///
+/// Example #2:
+///
+/// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
+/// %1 = vector.extract %0[6] : i32 from vector<8xi32>
+///
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Skip if result is not a vector type
- if (!isa<VectorType>(extractOp.getType()))
- return rewriter.notifyMatchFailure(extractOp,
- "scalar extract not supported");
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
assert(dstTy && "expected 1-D vector type");
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
linearizedOffset += offsets[i] * size;
}
+ if (!isa<VectorType>(extractOp.getType())) {
+ // Scalar case: generate a 1-D extract.
+ Value result = rewriter.createOrFold<vector::ExtractOp>(
+ extractOp.getLoc(), adaptor.getVector(), linearizedOffset);
+ rewriter.replaceOp(extractOp, result);
+ return success();
+ }
+
+ // Vector case: generate a shuffle.
+
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 894171500d9d6..cbc15f34918f6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
// -----
+// CHECK-LABEL: test_vector_extract_scalar
+// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
+func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
+
+ // CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
+ // CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
+ // CHECK: return %[[EXTRACT_1D]] : i32
+ %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+ return %0 : i32
+}
+
+// -----
+
// CHECK-LABEL: test_vector_extract
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// -----
-// CHECK-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar(%idx : index) {
- %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
-
- // CHECK-NOT: vector.shuffle
- // CHECK: vector.extract
- // CHECK-NOT: vector.shuffle
- %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
- return
-}
-
-// -----
-
// CHECK-LABEL: test_vector_bitcast
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
More information about the Mlir-commits
mailing list