[Mlir-commits] [mlir] [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) (PR #114559)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 4 12:53:03 PST 2024
================
@@ -4698,6 +4698,114 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
+/// Decides whether `op` is eligible for having producer `tensor.cast` ops
+/// folded into its operands.
+///
+/// Returns false for `InsertSliceOp` and for ops that implement
+/// `LoopLikeOpInterface`. Otherwise returns true iff at least one operand is
+/// defined by a `tensor::CastOp` for which `canFoldIntoConsumerOp` holds.
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  Operation *underlyingOp = op.getOperation();
+
+  // InsertSliceOp folds tensor.cast producers through its own dedicated logic.
+  if (isa<InsertSliceOp>(underlyingOp))
+    return false;
+
+  // DPS ops that are also LoopLike are left out: their attached regions may
+  // need special handling.
+  if (isa<LoopLikeOpInterface>(underlyingOp))
+    return false;
+
+  // Succeed as soon as one operand is produced by a foldable tensor.cast.
+  for (OpOperand &operand : op->getOpOperands()) {
+    Value operandValue = operand.get();
+    // Block arguments have no defining op, hence nothing to fold.
+    if (llvm::isa<BlockArgument>(operandValue))
+      continue;
+    auto producerCast = operandValue.getDefiningOp<tensor::CastOp>();
+    if (producerCast && canFoldIntoConsumerOp(producerCast))
+      return true;
+  }
+
+  // No operand comes from a foldable tensor.cast.
+  return false;
+}
+
+/// Builds the operand list `op` would have once every foldable producer
+/// `tensor.cast` is bypassed, i.e. replaced by the cast's own operand.
+///
+/// For each DPS init operand whose (possibly replaced) value is not of
+/// `MemRefType`, the matching entry of `newResTy` is overwritten with that
+/// value's type. Entries are consumed in order starting from index 0, which
+/// assumes that the results tied to DPS inits come first in the result list.
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  // Index of the next result type to refresh; advanced once per non-memref
+  // DPS init operand.
+  int64_t nextInitResultIdx = 0;
+  for (OpOperand &operand : op->getOpOperands()) {
+    Value replacement = operand.get();
+    // `producerCast` is null when the operand is not defined by a tensor.cast;
+    // the foldability check is still queried with it, as in the original
+    // (presumably it rejects a null cast - confirm against its definition).
+    auto producerCast = replacement.getDefiningOp<tensor::CastOp>();
+    if (canFoldIntoConsumerOp(producerCast))
+      replacement = producerCast.getOperand();
+    newOperands.push_back(replacement);
+
+    if (!op.isDpsInit(&operand))
+      continue;
+    Type replacementType = replacement.getType();
+    if (llvm::isa<MemRefType>(replacementType))
+      continue;
+    newResTy[nextInitResultIdx++] = replacementType;
+  }
+  return newOperands;
+}
+
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has a source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+ using OpRewritePattern<PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PackOp op,
+ PatternRewriter &rewriter) const override {
+ if (!foldTensorCastPrecondition(op))
+ return failure();
+
+ SmallVector<Type> newResultTypes(op->getResultTypes());
+ SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+
+ // Get the updated mixed-tile-sizes attribute.
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+ .getShape()
+ .take_back(op.getMixedTiles().size()),
+ op.getMixedTiles())) {
----------------
Max191 wrote:
nit:
```suggestion
for (auto [shape, innerTile] : llvm::zip_equal(cast<ShapedType>(newResultTypes[0])
.getShape()
.take_back(op.getMixedTiles().size()),
op.getMixedTiles())) {
```
https://github.com/llvm/llvm-project/pull/114559
More information about the Mlir-commits
mailing list