[Mlir-commits] [mlir] [mlir][vector] Fix rewrite pattern API violation in `VectorToSCF` (PR #77909)
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 12 03:46:22 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/77909
>From ecb9d3bb995fc9fee95081a21778a1d880a48231 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 12 Jan 2024 11:30:50 +0000
Subject: [PATCH] [mlir][vector] Fix rewrite pattern API violation in
`VectorToSCF`
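A pattern must not modify the IR if it returns "failure". `UnrollTransferReadConversion` previously built the result vector (a `vector.splat`, see the log below) before checking whether the leading dimension is scalable, and then bailed out with "failure" after the IR had already changed. This commit moves all match checks ahead of any IR creation, which fixes a failure in `test/Conversion/VectorToSCF/vector-to-scf.mlir` when it is run with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.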
```
Processing operation : 'vector.transfer_read'(0x55823a409a60) {
%5 = "vector.transfer_read"(%arg0, %0, %0, %2, %4) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 1>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<?x4xf32>, index, index, f32, vector<[4]x4xi1>) -> vector<[4]x4xf32>
* Pattern (anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion : 'vector.transfer_read -> ()' {
Trying to match "(anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion"
** Insert : 'vector.splat'(0x55823a445640)
"(anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion" result 0
} -> failure : pattern failed to match
LLVM ERROR: pattern returned failure but IR did change
```
---
.../Conversion/VectorToSCF/VectorToSCF.cpp | 29 ++++++++++---------
1 file changed, 16 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index a1aff1ab36a52b..9cb86bde7dc600 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1060,10 +1060,10 @@ struct UnrollTransferReadConversion
setHasBoundedRewriteRecursion();
}
- /// Return the vector into which the newly created TransferReadOp results
- /// are inserted.
- Value getResultVector(TransferReadOp xferOp,
- PatternRewriter &rewriter) const {
+ /// Get or build the vector into which the newly created TransferReadOp
+ /// results are inserted.
+ Value buildResultVector(PatternRewriter &rewriter,
+ TransferReadOp xferOp, ) const {
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
@@ -1098,24 +1098,27 @@ struct UnrollTransferReadConversion
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
- return failure();
+ return rewriter.notifyMatchFailure(
+ xferOp, "vector rank is less or equal to target rank");
if (isTensorOp(xferOp) && !options.lowerTensors)
- return failure();
+ return rewriter.notifyMatchFailure(
+ xferOp, "transfers operating on tensors are excluded");
// Transfer ops that modify the element type are not supported atm.
if (xferOp.getVectorType().getElementType() !=
xferOp.getShapedType().getElementType())
- return failure();
-
- auto insertOp = getInsertOp(xferOp);
- auto vec = getResultVector(xferOp, rewriter);
- auto vecType = dyn_cast<VectorType>(vec.getType());
+ return rewriter.notifyMatchFailure(
+ xferOp, "not yet supported: element type mismatch");
auto xferVecType = xferOp.getVectorType();
-
if (xferVecType.getScalableDims()[0]) {
// Cannot unroll a scalable dimension at compile time.
- return failure();
+ return rewriter.notifyMatchFailure(
+ xferOp, "scalable dimensions cannot be unrolled");
}
+ auto insertOp = getInsertOp(xferOp);
+ auto vec = buildResultVector(rewriter, xferOp);
+ auto vecType = dyn_cast<VectorType>(vec.getType());
+
VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
int64_t dimSize = xferVecType.getShape()[0];
More information about the Mlir-commits
mailing list