[Mlir-commits] [mlir] [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) (PR #114559)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 1 09:09:59 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:
```mlir
    %pack = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
```
as follows (note the now-static inner-tile dim `8` in the result type, while
the corresponding tile size, `%c8`, is still a dynamic SSA value):
```mlir
    %res = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```

This triggers an Op verification failure because the folder updates the
result type but does not update the inner tile sizes of the pack Op. This PR
addresses that.

Note: supporting other Ops with size-like attributes (e.g. `tensor.unpack`) is left as a TODO.


---
Full diff: https://github.com/llvm/llvm-project/pull/114559.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+44-2) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+21-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c2d6bc610cd92a..406b557b0f0e39 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4756,8 +4756,50 @@ struct FoldTensorCastProducerOp
         newResultTypes[dpsInitIdx++] = newOperands.back().getType();
     }
 
-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
+    // For ops that have sizes-like attribute, update these accordingly.
+    // For now, only `tensor.pack` is supported.
+    // TODO: Generalize to make it work with other ops as well (e.g.
+    // `tensor.unpack`)
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    if (auto pack = dyn_cast_or_null<tensor::PackOp>(*op)) {
+      for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                   .getShape()
+                                   .take_back(pack.getMixedTiles().size()),
+                               pack.getMixedTiles())) {
+
+        int64_t shape = std::get<0>(it);
+        if (shape == ShapedType::kDynamic) {
+          newMixedTileSizes.push_back(std::get<1>(it));
+          continue;
+        }
+
+        if (Attribute attr =
+                llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+          // Already a constant
+          newMixedTileSizes.push_back(std::get<1>(it));
+        } else {
+          auto tileSize = getConstantIntValue(std::get<1>(it));
+          assert(tileSize == shape && "tile size and dim size don't match!");
+          newMixedTileSizes.push_back(
+              (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+        }
+      }
+    }
+
+    // Clone op. For ops that have sizes-like attribute, make sure to udpate
+    // those as well. For now, only `tensor.pack` is supported.
+    // TODO: Generalize to make it work with other ops as well (e.g.
+    // `tensor.unpack`)
+    // Operation *newOp;
+    Operation *newOp;
+    if (auto pack = dyn_cast_or_null<tensor::PackOp>(*op)) {
+      newOp = rewriter.create<PackOp>(
+          pack.getLoc(), newOperands[0], newOperands[1], pack.getInnerDimsPos(),
+          newMixedTileSizes, pack.getPaddingValue(), pack.getOuterDimsPerm());
+    } else {
+      newOp = clone(rewriter, op, newResultTypes, newOperands);
+    }
+
     SmallVector<Value, 4> replacements;
     replacements.reserve(newOp->getNumResults());
     for (auto [oldResult, newResult] :
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 693079c3aa2fac..ebcc69250ad56d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @test_destination_multiple_result(
+// CHECK-LABEL:   func.func @fold_cast_multiple_results(
 // CHECK-SAME:         %[[ARG1:.*]]: tensor<2x2xf32>,
 // CHECK-SAME:         %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
 // CHECK:           %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
 // CHECK-SAME:      outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
 // CHECK:           return %[[RES]]#1 : index
-func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
+func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
   %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
   %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+// -----
+
+// CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
+// CHECK-SAME:      %[[DEST:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[SRC:.*]]: tensor<7x?xi32>,
+// CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
+// CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
+// CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
+func.func @fold_cast_pack_dynamic_tile_size(
+  %dest: tensor<1x1x8x1xi32>,
+  %src: tensor<7x?xi32>,
+  %pad: i32) -> tensor<1x1x8x1xi32> {
+
+    %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+    %c8 = arith.constant 8 : index
+    %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
+    return %res : tensor<1x1x8x1xi32>
+}
 
 // -----
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/114559


More information about the Mlir-commits mailing list