[Mlir-commits] [mlir] [mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering (PR #124861)
Diego Caballero
llvmlistbot at llvm.org
Tue Jan 28 16:10:07 PST 2025
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/124861
It looks like scalable `vector.insert_strided_slice`/`vector.extract_strided_slice` ops made their way through lowering patterns that generate `vector.shuffle` ops, even though `vector.shuffle` does not support scalable vectors. I'm not sure why this wasn't caught by the verifier; probably the shuffle op was folded into something else as part of the same rewrite, so the intermediate IR was never verified.
This PR fixes the issue by preventing scalable `vector.insert_strided_slice`/`vector.extract_strided_slice` ops from being lowered to vector shuffles. Instead, scalable `vector.extract_strided_slice` ops are now lowered to a chain of extract/insert ops using an existing pattern (`populateVectorExtractStridedSliceToExtractInsertChainPatterns`).
From a25b5fe583807e7d393f15dc9f469e4cf051f58e Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 28 Jan 2025 15:58:18 -0800
Subject: [PATCH] [mlir][Vector] Fix scalable Insert/ExtractSlice lowering
It looks like scalable `vector.insert_strided_slice` and
`vector.extract_strided_slice` ops made their way through lowering
patterns that generate `vector.shuffle` ops, even though
`vector.shuffle` does not support scalable vectors. I'm not sure why
this wasn't caught by the verifier; probably the shuffle op was folded
into something else as part of the same rewrite, so the intermediate
IR was never verified.

This PR fixes the issue by preventing scalable
`vector.insert_strided_slice`/`vector.extract_strided_slice` ops from
being lowered to vector shuffles. Instead, scalable
`vector.extract_strided_slice` ops are now lowered to a chain of
extract/insert ops using an existing pattern
(`populateVectorExtractStridedSliceToExtractInsertChainPatterns`).
---
...sertExtractStridedSliceRewritePatterns.cpp | 21 ++++++++++++++++++-
.../VectorToLLVM/vector-to-llvm.mlir | 15 ++++++-------
2 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f13..2c32634544b90b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
                                 PatternRewriter &rewriter) const override {
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
+    int64_t srcRank = srcType.getRank();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+      return failure();
 
     if (op.getOffsets().getValue().empty())
       return failure();
 
-    int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
     assert(dstRank >= srcRank);
     if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    auto srcType = op.getSourceVectorType();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if (dstType.isScalable() || srcType.isScalable())
+      return failure();
 
     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
 
@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
                                                         benefit);
+  // Generate chains of extract/insert ops for scalable vectors only as they
+  // can't be lowered to vector shuffles.
+  populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+      patterns,
+      /*controlFn=*/
+      [](ExtractStridedSliceOp op) {
+        return op.getType().isScalable() ||
+               op.getSourceVectorType().isScalable();
+      },
+      benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1a..7df6defc0f202f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL: func.func @extract_strided_slice_f32_1d_from_2d_scalable(
 // CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
 // CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-// CHECK: return %[[T5]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+// CHECK: %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+// CHECK: return %[[RES]]
 
 // -----
 