[Mlir-commits] [mlir] [mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering (PR #124861)
llvmlistbot at llvm.org
Tue Jan 28 16:10:40 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
It looks like scalable `vector.insert_strided_slice`/`vector.extract_strided_slice` ops made their way through lowering patterns that generate `vector.shuffle` ops, which do not support scalable vectors. I'm not sure why the verifier didn't catch this; probably the shuffle op was folded into something else as part of the same rewrite, so the IR was never verified in that state.
This PR fixes the issue by preventing scalable `vector.insert_strided_slice`/`vector.extract_strided_slice` ops from being lowered to vector shuffles. Instead, they are now lowered to a chain of extract/insert ops using an existing pattern.
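To illustrate the idea behind the extract/insert chain, here is a rough standalone C++ model; it is not the MLIR implementation and uses no MLIR APIs. `ScalableRow` and `extractOuterSlice` are hypothetical names introduced for this sketch. It assumes the slice is taken along the outer, fixed-size dimension, as in the updated test, where rows 2 and 3 of a `vector<4x[8]xf32>` are copied into a `vector<2x[8]xf32>`. Each scalable row is moved as an opaque value, so no compile-time element indices (which a shuffle mask would require) are needed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for one scalable 1-D vector such as vector<[8]xf32>: its length
// (vscale * 8) is only known at run time, so it is treated as an opaque value.
using ScalableRow = std::vector<float>;

// Hypothetical helper (not an MLIR API): slices the outer, fixed-size
// dimension by extracting whole rows and inserting them into the result.
std::vector<ScalableRow> extractOuterSlice(const std::vector<ScalableRow> &src,
                                           std::size_t offset, std::size_t size,
                                           std::size_t stride) {
  assert(stride > 0 && "stride must be positive");
  assert((size == 0 || offset + (size - 1) * stride < src.size()) &&
         "slice must stay within the source");
  std::vector<ScalableRow> dst(size);
  for (std::size_t i = 0; i < size; ++i) {
    // One extract from the source followed by one insert into the result.
    // The elements inside a row are never indexed individually.
    dst[i] = src[offset + i * stride];
  }
  return dst;
}
```

Calling `extractOuterSlice(src, 2, 2, 1)` on a four-row source mirrors the `extractvalue [2]`, `[3]` and `insertvalue [0]`, `[1]` sequence in the new CHECK lines; the offset, size, and stride values are inferred from those lines rather than taken from the test's op attributes.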
---
Full diff: https://github.com/llvm/llvm-project/pull/124861.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+20-1)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+8-7)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f1..2c32634544b90b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
+ int64_t srcRank = srcType.getRank();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+ return failure();
if (op.getOffsets().getValue().empty())
return failure();
- int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
assert(dstRank >= srcRank);
if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
+ auto srcType = op.getSourceVectorType();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if (dstType.isScalable() || srcType.isScalable())
+ return failure();
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
benefit);
+ // Generate chains of extract/insert ops for scalable vectors only as they
+ // can't be lowered to vector shuffles.
+ populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ patterns,
+ /*controlFn=*/
+ [](ExtractStridedSliceOp op) {
+ return op.getType().isScalable() ||
+ op.getSourceVectorType().isScalable();
+ },
+ benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1..7df6defc0f202f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
// CHECK-LABEL: func.func @extract_strided_slice_f32_1d_from_2d_scalable(
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-// CHECK: return %[[T5]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+// CHECK: %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+// CHECK: return %[[RES]]
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/124861