[Mlir-commits] [mlir] [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) (PR #114559)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Nov 5 09:59:35 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/114559

>From d40f7052348001349164d13a50c2beff164373e8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 1 Nov 2024 15:59:47 +0000
Subject: [PATCH 1/3] [mlir][tensor] Improve `FoldTensorCastProducerOp`
 (dynamic shapes)

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:
```mlir
    %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
    %pack = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
```
into the following. Note that the tiled dimension in the result type is now static (`8`), while the corresponding inner tile size, `%c8`, is still an SSA value, i.e. dynamic:
```mlir
    %res = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %dest : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```

This triggers an Op verification failure. The folder does not update the inner tile sizes of the pack Op to match the now-static result shape. This PR addresses that by replacing such tile sizes with the corresponding static values.

Note: supporting other Ops with size-like attributes (e.g. `tensor.unpack`) is left as a TODO.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 46 +++++++++++++++++++++-
 mlir/test/Dialect/Tensor/canonicalize.mlir | 23 ++++++++++-
 2 files changed, 65 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c2d6bc610cd92a..406b557b0f0e39 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4756,8 +4756,50 @@ struct FoldTensorCastProducerOp
         newResultTypes[dpsInitIdx++] = newOperands.back().getType();
     }
 
-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
+    // For ops that have sizes-like attribute, update these accordingly.
+    // For now, only `tensor.pack` is supported.
+    // TODO: Generalize to make it work with other ops as well (e.g.
+    // `tensor.unpack`)
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    if (auto pack = dyn_cast_or_null<tensor::PackOp>(*op)) {
+      for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                   .getShape()
+                                   .take_back(pack.getMixedTiles().size()),
+                               pack.getMixedTiles())) {
+
+        int64_t shape = std::get<0>(it);
+        if (shape == ShapedType::kDynamic) {
+          newMixedTileSizes.push_back(std::get<1>(it));
+          continue;
+        }
+
+        if (Attribute attr =
+                llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+          // Already a constant
+          newMixedTileSizes.push_back(std::get<1>(it));
+        } else {
+          auto tileSize = getConstantIntValue(std::get<1>(it));
+          assert(tileSize == shape && "tile size and dim size don't match!");
+          newMixedTileSizes.push_back(
+              (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+        }
+      }
+    }
+
+    // Clone op. For ops that have sizes-like attribute, make sure to udpate
+    // those as well. For now, only `tensor.pack` is supported.
+    // TODO: Generalize to make it work with other ops as well (e.g.
+    // `tensor.unpack`)
+    // Operation *newOp;
+    Operation *newOp;
+    if (auto pack = dyn_cast_or_null<tensor::PackOp>(*op)) {
+      newOp = rewriter.create<PackOp>(
+          pack.getLoc(), newOperands[0], newOperands[1], pack.getInnerDimsPos(),
+          newMixedTileSizes, pack.getPaddingValue(), pack.getOuterDimsPerm());
+    } else {
+      newOp = clone(rewriter, op, newResultTypes, newOperands);
+    }
+
     SmallVector<Value, 4> replacements;
     replacements.reserve(newOp->getNumResults());
     for (auto [oldResult, newResult] :
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 693079c3aa2fac..ebcc69250ad56d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @test_destination_multiple_result(
+// CHECK-LABEL:   func.func @fold_cast_multiple_results(
 // CHECK-SAME:         %[[ARG1:.*]]: tensor<2x2xf32>,
 // CHECK-SAME:         %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
 // CHECK:           %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
 // CHECK-SAME:      outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
 // CHECK:           return %[[RES]]#1 : index
-func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
+func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
   %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
   %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+// -----
+
+// CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
+// CHECK-SAME:      %[[DEST:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[SRC:.*]]: tensor<7x?xi32>,
+// CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
+// CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
+// CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
+func.func @fold_cast_pack_dynamic_tile_size(
+  %dest: tensor<1x1x8x1xi32>,
+  %src: tensor<7x?xi32>,
+  %pad: i32) -> tensor<1x1x8x1xi32> {
+
+    %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+    %c8 = arith.constant 8 : index
+    %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
+    return %res : tensor<1x1x8x1xi32>
+}
 
 // -----
 

>From b7b56b1f7b308f854424da0ee0b927401b5ae4d0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 4 Nov 2024 19:01:39 +0000
Subject: [PATCH 2/3] fixup! [mlir][tensor] Improve `FoldTensorCastProducerOp`
 (dynamic shapes)

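Split into two patterns: a dedicated `FoldTensorCastPackOp` for `tensor.pack`, and the generic `FoldTensorCastProducerOp` for all other DPS ops.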
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 187 ++++++++++++++---------
 1 file changed, 114 insertions(+), 73 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 406b557b0f0e39..2f0d7d441e19ce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4698,6 +4698,114 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  // InsertSliceOp has its own logic about folding tensor.cast ops.
+  if (isa<InsertSliceOp>(op.getOperation()))
+    return false;
+
+  // Exclude DPS ops that are also LoopLike from this interface as they
+  // might need special handling of attached regions.
+  if (isa<LoopLikeOpInterface>(op.getOperation()))
+    return false;
+
+  // If no operand comes from a tensor::CastOp and can be folded then fail.
+  bool hasTensorCastOperand =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+        if (llvm::isa<BlockArgument>(opOperand.get()))
+          return false;
+        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+        return castOp && canFoldIntoConsumerOp(castOp);
+      });
+
+  return hasTensorCastOperand;
+}
+
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        auto tileSize = getConstantIntValue(std::get<1>(it));
+        assert(tileSize == shape && "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+
+    // Clone op.
+    PackOp newOp = rewriter.create<PackOp>(
+        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
+
+    SmallVector<Value, 4> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto [oldResult, newResult] :
+         llvm::zip(op->getResults(), newOp->getResults())) {
+      newResult.getType() != oldResult.getType()
+          ? replacements.push_back(rewriter.create<tensor::CastOp>(
+                op->getLoc(), oldResult.getType(), newResult))
+          : replacements.push_back(newResult);
+    }
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
 
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
@@ -4722,83 +4830,15 @@ struct FoldTensorCastProducerOp
 
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    // InsertSliceOp has its own logic about folding tensor.cast ops.
-    if (isa<InsertSliceOp>(op.getOperation()))
-      return failure();
-
-    // Exclude DPS ops that are also LoopLike from this interface as they
-    // might need special handling of attached regions.
-    if (isa<LoopLikeOpInterface>(op.getOperation()))
-      return failure();
 
-    // If no operand comes from a tensor::CastOp and can be folded then fail.
-    bool hasTensorCastOperand =
-        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-          if (llvm::isa<BlockArgument>(opOperand.get()))
-            return false;
-          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-          return castOp && canFoldIntoConsumerOp(castOp);
-        });
-    if (!hasTensorCastOperand)
+    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
       return failure();
 
-    SmallVector<Type, 4> newResultTypes(op->getResultTypes());
-    SmallVector<Value, 4> newOperands;
-    newOperands.reserve(op->getNumOperands());
-    // Assumes that the result has dpsInits followed by nonDpsInits.
-    int64_t dpsInitIdx = 0;
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-      bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-      if (op.isDpsInit(&opOperand) &&
-          !llvm::isa<MemRefType>(newOperands.back().getType()))
-        newResultTypes[dpsInitIdx++] = newOperands.back().getType();
-    }
-
-    // For ops that have sizes-like attribute, update these accordingly.
-    // For now, only `tensor.pack` is supported.
-    // TODO: Generalize to make it work with other ops as well (e.g.
-    // `tensor.unpack`)
-    SmallVector<OpFoldResult> newMixedTileSizes;
-    if (auto pack = dyn_cast_or_null<tensor::PackOp>(*op)) {
-      for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
-                                   .getShape()
-                                   .take_back(pack.getMixedTiles().size()),
-                               pack.getMixedTiles())) {
-
-        int64_t shape = std::get<0>(it);
-        if (shape == ShapedType::kDynamic) {
-          newMixedTileSizes.push_back(std::get<1>(it));
-          continue;
-        }
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
-        if (Attribute attr =
-                llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
-          // Already a constant
-          newMixedTileSizes.push_back(std::get<1>(it));
-        } else {
-          auto tileSize = getConstantIntValue(std::get<1>(it));
-          assert(tileSize == shape && "tile size and dim size don't match!");
-          newMixedTileSizes.push_back(
-              (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
-        }
-      }
-    }
-
-    // Clone op. For ops that have sizes-like attribute, make sure to udpate
-    // those as well. For now, only `tensor.pack` is supported.
-    // TODO: Generalize to make it work with other ops as well (e.g.
-    // `tensor.unpack`)
-    // Operation *newOp;
-    Operation *newOp;
-    if (auto pack = dyn_cast_or_null<tensor::PackOp>(*op)) {
-      newOp = rewriter.create<PackOp>(
-          pack.getLoc(), newOperands[0], newOperands[1], pack.getInnerDimsPos(),
-          newMixedTileSizes, pack.getPaddingValue(), pack.getOuterDimsPerm());
-    } else {
-      newOp = clone(rewriter, op, newResultTypes, newOperands);
-    }
+    // Clone op
+    auto newOp = clone(rewriter, op, newResultTypes, newOperands);
 
     SmallVector<Value, 4> replacements;
     replacements.reserve(newOp->getNumResults());
@@ -4823,6 +4863,7 @@ struct FoldTensorCastProducerOp
 
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
+  results.add<FoldTensorCastPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
 

>From 94316c7adedf626b13d6a028ad47f0d4373bf0a5 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 5 Nov 2024 17:59:00 +0000
Subject: [PATCH 3/3] fixup! fixup! [mlir][tensor] Improve
 `FoldTensorCastProducerOp` (dynamic shapes)

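Final tweaks: merge the precondition checks, unwrap the constant tile size explicitly, and simplify result replacement in `FoldTensorCastPackOp`.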
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 32 +++++++++++-------------
 1 file changed, 15 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2f0d7d441e19ce..1847066b2d1e36 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4699,13 +4699,11 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
-  // InsertSliceOp has its own logic about folding tensor.cast ops.
-  if (isa<InsertSliceOp>(op.getOperation()))
-    return false;
-
-  // Exclude DPS ops that are also LoopLike from this interface as they
+  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
+  // 2. Exclude DPS ops that are also LoopLike from this interface as they
   // might need special handling of attached regions.
-  if (isa<LoopLikeOpInterface>(op.getOperation()))
+  if (isa<InsertSliceOp>(op.getOperation()) ||
+      isa<LoopLikeOpInterface>(op.getOperation()))
     return false;
 
   // If no operand comes from a tensor::CastOp and can be folded then fail.
@@ -4780,7 +4778,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         // Already a constant
         newMixedTileSizes.push_back(std::get<1>(it));
       } else {
-        auto tileSize = getConstantIntValue(std::get<1>(it));
+        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
         assert(tileSize == shape && "tile size and dim size don't match!");
         newMixedTileSizes.push_back(
             (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
@@ -4792,16 +4790,15 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
 
-    SmallVector<Value, 4> replacements;
-    replacements.reserve(newOp->getNumResults());
-    for (auto [oldResult, newResult] :
-         llvm::zip(op->getResults(), newOp->getResults())) {
-      newResult.getType() != oldResult.getType()
-          ? replacements.push_back(rewriter.create<tensor::CastOp>(
-                op->getLoc(), oldResult.getType(), newResult))
-          : replacements.push_back(newResult);
-    }
-    rewriter.replaceOp(op, replacements);
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
 
     return success();
   }
@@ -4831,6 +4828,7 @@ struct FoldTensorCastProducerOp
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const override {
 
+    // Reject tensor::PackOp - there's dedicated pattern for that instead.
     if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
       return failure();
 



More information about the Mlir-commits mailing list