[Mlir-commits] [mlir] [mlir][tensor] Introduce `FoldTensorCastUnPackOp` (PR #121393)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 31 06:33:36 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change
mirrors a similar update made for `tensor::PackOp` in #114559. Below is
the updated rationale for `tensor::UnPackOp`.

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%unpack = tensor.unpack %cast
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```
as:

```mlir
%unpack = tensor.unpack %dest
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```

This leads to an Op verification failure because the folder does not
update the inner tile sizes in the unpack Op: the tile `%c8` stays dynamic
while the matching source dim is now static (8). This patch resolves the
issue.

Additional Changes:
* invalid.mlir: Fixes a typo in a test name (`@pack_invalid_output_rank` -> `@unpack_invalid_output_rank`).
* TensorOps.cpp: Removes unnecessary `(void)tileSize` and adds extra
  comments following this discussion:
  https://github.com/llvm/llvm-project/pull/115772.


---
Full diff: https://github.com/llvm/llvm-project/pull/121393.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+84-4) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+21) 
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+1-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..ee9a4012a01393 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         // Already a constant
         newMixedTileSizes.push_back(std::get<1>(it));
       } else {
-        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-        assert(tileSize == shape && "tile size and dim size don't match!");
-        (void)tileSize;
+        assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+               "tile size and dim size don't match!");
         newMixedTileSizes.push_back(
             (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
       }
     }
 
     // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
     PackOp newOp = rewriter.create<PackOp>(
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };
 
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has a source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute. The inner tiles correspond
+    // to the trailing dims of the (packed) source tensor.
+    SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
+    ArrayRef<int64_t> innerShape = cast<ShapedType>(sourceTensor.getType())
+                                       .getShape()
+                                       .take_back(mixedTiles.size());
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto [shape, tile] : llvm::zip_equal(innerShape, mixedTiles)) {
+      // If the current source dim is dynamic, just preserve this mixed size.
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(tile);
+        continue;
+      }
+
+      // The source dim is static: keep constant tile sizes as they are and
+      // replace dynamic (SSA) tile sizes with the static dim size. The SSA
+      // tile need not be a constant, so only cross-check it when it is.
+      if (llvm::isa_and_present<Attribute>(tile)) {
+        newMixedTileSizes.push_back(tile);
+        continue;
+      }
+      assert((!getConstantIntValue(tile) ||
+              *getConstantIntValue(tile) == shape) &&
+             "tile size and dim size don't match!");
+      newMixedTileSizes.push_back(
+          rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
+    }
+
+    // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
+    UnPackOp newOp = rewriter.create<UnPackOp>(
+        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getOuterDimsPerm());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Replace op, casting back to the original result type if folding the
+    // casts changed it, so that existing users remain valid.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+    return success();
+  }
+};
+
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
 ///
@@ -4890,7 +4969,7 @@ struct FoldTensorCastProducerOp
                                 PatternRewriter &rewriter) const override {
 
     // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) || isa<tensor::PackOp, tensor::UnPackOp>(*op))
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +5002,8 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<FoldTensorCastPackOp>(getContext());
-  results.add<FoldTensorCastProducerOp>(getContext());
+  // FoldTensorCastProducerOp rejects tensor.pack/tensor.unpack; those ops are
+  // handled by the dedicated FoldTensorCastPackOp/FoldTensorCastUnPackOp.
+  results.add<FoldTensorCastPackOp, FoldTensorCastUnPackOp,
+              FoldTensorCastProducerOp>(getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e18f..88e3691e2d6297 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
 // -----
 
 // CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(
 
 // -----
 
+// CHECK-LABEL:   func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK:           return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+  %src: tensor<1x1x8x1xi32>,
+  %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+    // Folding the cast must also turn the dynamic tile %c8 into a static 8.
+    %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+    %c8 = arith.constant 8 : index
+    %unpack = tensor.unpack %cast
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%c8, 1]
+      into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+    return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @pack_dont_drop_attributes(
 // CHECK: tensor.pack {{.*}}  {test_attr}
 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab247..1de3e281bc462b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
 
 // -----
 
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
   // expected-error at +1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
   %0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/121393


More information about the Mlir-commits mailing list