[Mlir-commits] [mlir] [mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering (PR #124861)
Diego Caballero
llvmlistbot at llvm.org
Fri Jan 31 14:13:16 PST 2025
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/124861
>From 298fff777b00be59fb2fa45085d8f892b165eafd Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 28 Jan 2025 15:58:18 -0800
Subject: [PATCH 1/2] [mlir][Vector] Fix scalable Insert/ExtractSlice lowering
It looks like scalable `vector.insert_strided_slice`/`vector.extract_strided_slice`
ops made their way through lowering patterns that generate `vector.shuffle` ops.
I'm not sure why this wasn't caught by the verifier. Probably the shuffle op was
folded into something else as part of the same rewrite, so the IR was never
verified.
This PR fixes the issue by preventing scalable `vector.insert_strided_slice`/
`vector.extract_strided_slice` ops from being lowered to vector shuffles.
Instead, scalable `vector.extract_strided_slice` ops are now lowered to a chain
of extract/insert ops using an existing pattern.
---
...sertExtractStridedSliceRewritePatterns.cpp | 21 ++++++++++++++++++-
.../VectorToLLVM/vector-to-llvm.mlir | 15 ++++++-------
2 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f1..2c32634544b90b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
+ int64_t srcRank = srcType.getRank();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+ return failure();
if (op.getOffsets().getValue().empty())
return failure();
- int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
assert(dstRank >= srcRank);
if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
+ auto srcType = op.getSourceVectorType();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if (dstType.isScalable() || srcType.isScalable())
+ return failure();
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
benefit);
+ // Generate chains of extract/insert ops for scalable vectors only as they
+ // can't be lowered to vector shuffles.
+ populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ patterns,
+ /*controlFn=*/
+ [](ExtractStridedSliceOp op) {
+ return op.getType().isScalable() ||
+ op.getSourceVectorType().isScalable();
+ },
+ benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1..7df6defc0f202f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
// CHECK-LABEL: func.func @extract_strided_slice_f32_1d_from_2d_scalable(
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-// CHECK: return %[[T5]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+// CHECK: %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+// CHECK: return %[[RES]]
// -----
>From 2dfe220a09f8e0abbf1bee860a1d0f12a3692842 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 31 Jan 2025 14:09:21 -0800
Subject: [PATCH 2/2] Add TODO.
---
.../VectorInsertExtractStridedSliceRewritePatterns.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 2c32634544b90b..82a985c9e58248 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -318,6 +318,8 @@ class DecomposeNDExtractStridedSlice
}
};
+// TODO: Make sure these `populate*` patterns are tested in isolation.
+
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
More information about the Mlir-commits
mailing list