[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 15 06:11:28 PST 2023
================
@@ -335,23 +348,50 @@ class TransferReadDropUnitDimsPattern
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
+ auto reducedVectorType = trimUnitDims(vectorType);
+ if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
+
+ auto maskOp = transferReadOp.getMask();
+ if (maskOp) {
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+ auto maskType = maskOp.getType();
+ auto reducedMaskType = trimUnitDims(maskType);
+ if (reducedMaskType.getRank() == maskType.getRank())
+ return failure();
+ SmallVector<Value> maskOperands;
+ for (auto [dim, dimIsScalable, maskOperand] :
+ llvm::zip(maskType.getShape(), maskType.getScalableDims(),
+ createMaskOp.getOperands())) {
+ if (dim == 1 && !dimIsScalable) {
+ // If the mask for the unit dim is not a constant of 1, do nothing.
+ auto constant = maskOperand.getDefiningOp<arith::ConstantIndexOp>();
+ if (!constant || (constant.value() != 1))
+ return failure();
+ continue;
----------------
nicolasvasilache wrote:
Can we hoist this logic into a helper with a descriptive name? The nesting here is already deep enough.
https://github.com/llvm/llvm-project/pull/72142
More information about the Mlir-commits mailing list