[Mlir-commits] [mlir] 9b9369e - [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) (#114559)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 5 10:54:56 PST 2024

Author: Andrzej Warzyński
Date: 2024-11-05T18:54:52Z
New Revision: 9b9369e0bb0131ba0336d9adb4ef098b6dafc7f4

URL: https://github.com/llvm/llvm-project/commit/9b9369e0bb0131ba0336d9adb4ef098b6dafc7f4
DIFF: https://github.com/llvm/llvm-project/commit/9b9369e0bb0131ba0336d9adb4ef098b6dafc7f4.diff

LOG: [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) (#114559)

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:
    %pack = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
as (note the static trailing dim in the result type and the still-dynamic
inner tile size, `%c8`, that corresponds to it):
    %res = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast : tensor<7x?xi32> -> tensor<1x1x8x1xi32>

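This triggers an Op verification failure: the folder makes the result
type static but does not update the inner tile sizes of the pack Op, so a
static result dim ends up paired with a dynamic tile size. This PR
addresses that by adding a dedicated pattern for `tensor.pack` that also
updates the inner tile sizes.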

Note, supporting other Ops with size-like attributes is left as a TODO.




diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c2d6bc610cd92a..1847066b2d1e36 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4698,6 +4698,111 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 // Common Canonicalizers and Folders.
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
+  // 2. Exclude DPS ops that are also LoopLike from this interface as they
+  // might need special handling of attached regions.
+  if (isa<InsertSliceOp>(op.getOperation()) ||
+      isa<LoopLikeOpInterface>(op.getOperation()))
+    return false;
+  // If no operand comes from a tensor::CastOp and can be folded then fail.
+  bool hasTensorCastOperand =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+        if (llvm::isa<BlockArgument>(opOperand.get()))
+          return false;
+        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+        return castOp && canFoldIntoConsumerOp(castOp);
+      });
+  return hasTensorCastOperand;
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+/// folds into:
+/// ```mlir
+///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(PackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
+        assert(tileSize == shape && "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+    // Clone op.
+    PackOp newOp = rewriter.create<PackOp>(
+        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+    rewriter.replaceOp(op, {replacement});
+    return success();
+  }
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
@@ -4722,42 +4827,17 @@ struct FoldTensorCastProducerOp
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    // InsertSliceOp has its own logic about folding tensor.cast ops.
-    if (isa<InsertSliceOp>(op.getOperation()))
-      return failure();
-    // Exclude DPS ops that are also LoopLike from this interface as they
-    // might need special handling of attached regions.
-    if (isa<LoopLikeOpInterface>(op.getOperation()))
+    // Reject tensor::PackOp - there's dedicated pattern for that instead.
+    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
       return failure();
-    // If no operand comes from a tensor::CastOp and can be folded then fail.
-    bool hasTensorCastOperand =
-        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-          if (llvm::isa<BlockArgument>(opOperand.get()))
-            return false;
-          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-          return castOp && canFoldIntoConsumerOp(castOp);
-        });
-    if (!hasTensorCastOperand)
-      return failure();
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
-    SmallVector<Type, 4> newResultTypes(op->getResultTypes());
-    SmallVector<Value, 4> newOperands;
-    newOperands.reserve(op->getNumOperands());
-    // Assumes that the result has dpsInits followed by nonDpsInits.
-    int64_t dpsInitIdx = 0;
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-      bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-      if (op.isDpsInit(&opOperand) &&
-          !llvm::isa<MemRefType>(newOperands.back().getType()))
-        newResultTypes[dpsInitIdx++] = newOperands.back().getType();
-    }
+    // Clone op
+    auto newOp = clone(rewriter, op, newResultTypes, newOperands);
-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
     SmallVector<Value, 4> replacements;
     for (auto [oldResult, newResult] :
@@ -4781,6 +4861,7 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
+  results.add<FoldTensorCastPackOp>(getContext());

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 236d2a3e60eb2c..2186aab9a527cc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
 // -----
-// CHECK-LABEL:   func.func @test_destination_multiple_result(
+// CHECK-LABEL:   func.func @fold_cast_multiple_results(
 // CHECK-SAME:         %[[ARG1:.*]]: tensor<2x2xf32>,
 // CHECK-SAME:         %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
 // CHECK:           %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
 // CHECK-SAME:      outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
 // CHECK:           return %[[RES]]#1 : index
-func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
+func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
   %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
   %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
+// -----
+// CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
+// CHECK-SAME:      %[[DEST:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[SRC:.*]]: tensor<7x?xi32>,
+// CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
+// CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
+// CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
+func.func @fold_cast_pack_dynamic_tile_size(
+  %dest: tensor<1x1x8x1xi32>,
+  %src: tensor<7x?xi32>,
+  %pad: i32) -> tensor<1x1x8x1xi32> {
+    %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+    %c8 = arith.constant 8 : index
+    %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
+    return %res : tensor<1x1x8x1xi32>
 // -----


More information about the Mlir-commits mailing list