[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 15 06:11:31 PST 2023
================
@@ -335,23 +348,50 @@ class TransferReadDropUnitDimsPattern
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
+ auto reducedVectorType = trimUnitDims(vectorType);
+ if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
+
+ auto maskOp = transferReadOp.getMask();
+ if (maskOp) {
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+ auto maskType = maskOp.getType();
+ auto reducedMaskType = trimUnitDims(maskType);
+ if (reducedMaskType.getRank() == maskType.getRank())
+ return failure();
+ SmallVector<Value> maskOperands;
+ for (auto [dim, dimIsScalable, maskOperand] :
+ llvm::zip(maskType.getShape(), maskType.getScalableDims(),
+ createMaskOp.getOperands())) {
+ if (dim == 1 && !dimIsScalable) {
+ // If the mask for the unit dim is not a constant of 1, do nothing.
+ auto constant = maskOperand.getDefiningOp<arith::ConstantIndexOp>();
+ if (!constant || (constant.value() != 1))
+ return failure();
+ continue;
+ }
+ maskOperands.push_back(maskOperand);
+ }
+ maskOp = rewriter.create<vector::CreateMaskOp>(loc, reducedMaskType,
+ maskOperands);
+ }
+
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
- auto reducedVectorType = VectorType::get(
- getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+ SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
+ loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
+ transferReadOp.getPadding(), maskOp,
----------------
nicolasvasilache wrote:
Thanks for adding the previously omitted mask!
https://github.com/llvm/llvm-project/pull/72142
More information about the Mlir-commits mailing list