[Mlir-commits] [mlir] [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) (PR #114559)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 4 12:53:03 PST 2024


================
@@ -4698,6 +4698,114 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
+/// Returns true when `op` is eligible for folding a producing `tensor.cast`
+/// into it: the op is neither an InsertSliceOp nor loop-like, and at least one
+/// of its operands is defined by a `tensor::CastOp` that
+/// `canFoldIntoConsumerOp` accepts.
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  // InsertSliceOp folds tensor.cast ops with its own logic, and DPS ops that
+  // are also LoopLike might need special handling of their attached regions,
+  // so both are rejected here.
+  if (isa<InsertSliceOp>(op.getOperation()) ||
+      isa<LoopLikeOpInterface>(op.getOperation()))
+    return false;
+
+  // Succeed as soon as one operand is produced by a foldable tensor.cast.
+  for (OpOperand &operand : op->getOpOperands()) {
+    // Block arguments are skipped up front; only op results are inspected.
+    if (llvm::isa<BlockArgument>(operand.get()))
+      continue;
+    auto producer = operand.get().getDefiningOp<tensor::CastOp>();
+    if (producer && canFoldIntoConsumerOp(producer))
+      return true;
+  }
+
+  // No operand comes from a tensor.cast that can be folded.
+  return false;
+}
+
+/// Builds the operand list for `op` with every foldable `tensor.cast` operand
+/// replaced by the operand of that cast; all other operands are kept as-is.
+///
+/// As a side effect, `newResTy` is updated in place: for each DPS init operand
+/// whose (possibly replaced) value is not of MemRefType, the next entry of
+/// `newResTy` is overwritten with that value's type. Entries are consumed in
+/// operand order starting at index 0, so this relies on the results that
+/// correspond to DPS inits coming first in `newResTy`.
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> replacements;
+  replacements.reserve(op->getNumOperands());
+
+  // Index of the next result type to refresh; advanced once per tensor init.
+  int64_t nextResultIdx = 0;
+  for (OpOperand &operand : op->getOpOperands()) {
+    Value current = operand.get();
+    // `producer` is null when the operand is not defined by a tensor.cast;
+    // the null value is handed to canFoldIntoConsumerOp unchanged.
+    auto producer = current.getDefiningOp<tensor::CastOp>();
+    if (canFoldIntoConsumerOp(producer))
+      current = producer.getOperand();
+    replacements.push_back(current);
+
+    // Only non-memref DPS inits contribute an updated result type.
+    if (!op.isDpsInit(&operand) || llvm::isa<MemRefType>(current.getType()))
+      continue;
+    newResTy[nextResultIdx++] = current.getType();
+  }
+  return replacements;
+}
+
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        auto tileSize = getConstantIntValue(std::get<1>(it));
+        assert(tileSize == shape && "tile size and dim size don't match!");
----------------
Max191 wrote:

nit: I think it is safe to assume the tile size is constant here, so you can unwrap the result with `.value()` and declare it as `int64_t` instead of `auto`. This makes the types clearer.
```suggestion
        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
        assert(tileSize == shape && "tile size and dim size don't match!");
```

https://github.com/llvm/llvm-project/pull/114559


More information about the Mlir-commits mailing list