[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 15 06:11:28 PST 2023
================
@@ -335,23 +348,50 @@ class TransferReadDropUnitDimsPattern
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
+ auto reducedVectorType = trimUnitDims(vectorType);
+ if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
+
+ auto maskOp = transferReadOp.getMask();
+ if (maskOp) {
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
----------------
nicolasvasilache wrote:
We now have quite a few failure cases that are non-trivial.
How about returning some informative messages with `rewriter.notifyMatchFailure(...)`?
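To make the suggestion concrete, here is a minimal standalone C++ model of the precondition checks visible in the hunk. It attaches a reason string to each bail-out instead of returning a bare `failure()`. It does not use MLIR. `Shape`, `ReadInfo`, `trimUnitDimsShape` and `checkPreconditions` are hypothetical names introduced only for illustration. `trimUnitDimsShape` assumes that `trimUnitDims` simply drops size-1 dimensions and ignores any scalable-dimension handling.

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a vector/memref shape; not an MLIR type.
using Shape = std::vector<int64_t>;

// Simplified model of trimUnitDims: keep only the dimensions whose size is not 1.
Shape trimUnitDimsShape(const Shape &shape) {
  Shape out;
  std::copy_if(shape.begin(), shape.end(), std::back_inserter(out),
               [](int64_t d) { return d != 1; });
  return out;
}

// Hypothetical summary of the vector.transfer_read being matched.
struct ReadInfo {
  Shape vectorShape;                            // shape of the read vector
  int64_t reducedSourceRank;                    // source rank after dropping unit dims
  std::vector<std::optional<int64_t>> indices;  // constant value, or nullopt if unknown
  bool hasMask;                                 // does the read carry a mask?
  bool maskFromCreateMask;                      // is that mask defined by vector.create_mask?
};

// Returns std::nullopt when every check passes, otherwise a human-readable
// reason. This is the kind of message the review suggests passing to
// rewriter.notifyMatchFailure instead of returning a bare failure().
std::optional<std::string> checkPreconditions(const ReadInfo &r) {
  auto reducedVectorRank =
      static_cast<int64_t>(trimUnitDimsShape(r.vectorShape).size());
  if (reducedVectorRank != r.reducedSourceRank)
    return "reduced vector rank does not match reduced source rank";
  // Mirrors getConstantIntValue(v) != 0: unknown (non-constant) indices also fail.
  for (const auto &idx : r.indices)
    if (!idx || *idx != 0)
      return "expected all transfer_read indices to be constant 0";
  if (r.hasMask && !r.maskFromCreateMask)
    return "unsupported mask: not produced by vector.create_mask";
  return std::nullopt;
}
```

In the actual pattern, each of these early exits would become something like `return rewriter.notifyMatchFailure(transferReadOp, "...");`, so that the reason is reported when the pattern is debugged instead of being lost.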
https://github.com/llvm/llvm-project/pull/72142