[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 15 06:11:28 PST 2023
================
@@ -335,23 +348,50 @@ class TransferReadDropUnitDimsPattern
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
+ auto reducedVectorType = trimUnitDims(vectorType);
+ if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
+
+ auto maskOp = transferReadOp.getMask();
+ if (maskOp) {
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+ auto maskType = maskOp.getType();
+ auto reducedMaskType = trimUnitDims(maskType);
+ if (reducedMaskType.getRank() == maskType.getRank())
+ return failure();
+ SmallVector<Value> maskOperands;
+ for (auto [dim, dimIsScalable, maskOperand] :
+ llvm::zip(maskType.getShape(), maskType.getScalableDims(),
----------------
nicolasvasilache wrote:
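Please use `llvm::zip_equal` instead of `llvm::zip` wherever possible. `zip_equal` asserts that all ranges have the same length, whereas `zip` silently stops at the end of the shortest range.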
https://github.com/llvm/llvm-project/pull/72142
More information about the Mlir-commits mailing list