[Mlir-commits] [mlir] [mlir][tensor] Introduce `FoldTensorCastUnPackOp` (PR #121393)
llvmlistbot at llvm.org
Tue Dec 31 06:33:36 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change
mirrors a similar update made for `tensor::PackOp` in #114559. Below is
the updated rationale for `tensor::UnPackOp`.
Currently, `FoldTensorCastProducerOp` incorrectly folds the following:
```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%unpack = tensor.unpack %cast
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```
as:
```mlir
%unpack = tensor.unpack %dest
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```
This leads to an Op verification failure: the folder makes the unpack
source type static without also updating the inner tile sizes, so
`inner_tiles = [%c8, 1]` no longer matches the static source shape. This
patch resolves the issue.
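With the fix, the dedicated pattern also rewrites `inner_tiles` to match
the now-static source shape, so the fold instead produces (as exercised by
the new `fold_cast_unpack_dynamic_tile_size` test in the diff below):
```mlir
%unpack = tensor.unpack %dest
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 1]
    into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```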
Additional Changes:
* invalid.mlir: Fixes a typo.
* TensorOps.cpp: Removes an unnecessary `(void)tileSize` and adds extra
comments following the discussion in
https://github.com/llvm/llvm-project/pull/115772.
---
Full diff: https://github.com/llvm/llvm-project/pull/121393.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+84-4)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+21)
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+1-1)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..ee9a4012a01393 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
// Already a constant
newMixedTileSizes.push_back(std::get<1>(it));
} else {
- int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
- assert(tileSize == shape && "tile size and dim size don't match!");
- (void)tileSize;
+ assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+ "tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
}
}
// Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
}
};
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(UnPackOp op,
+ PatternRewriter &rewriter) const override {
+ if (!foldTensorCastPrecondition(op))
+ return failure();
+
+ SmallVector<Type> newResultTypes(op->getResultTypes());
+ SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ Value sourceTensor = newOperands[0];
+
+ // Get the updated mixed-tile-sizes attribute.
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
+ .getShape()
+ .take_back(op.getMixedTiles().size()),
+ op.getMixedTiles())) {
+ int64_t shape = std::get<0>(it);
+ // If the current source shape is dynamic, just preserve this mixed
+ // size.
+ if (shape == ShapedType::kDynamic) {
+ newMixedTileSizes.push_back(std::get<1>(it));
+ continue;
+ }
+
+ // If the current source is static, update the dynamic mixed-size
+ // (provided the original value is dynamic).
+ if (Attribute attr =
+ llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+ // Already a constant
+ newMixedTileSizes.push_back(std::get<1>(it));
+ } else {
+ assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+ "tile size and dim size don't match!");
+ newMixedTileSizes.push_back(
+ (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ }
+ }
+
+ // Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
+ UnPackOp newOp = rewriter.create<UnPackOp>(
+ op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getOuterDimsPerm());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+ // Replace op.
+ Value oldResult = op.getResult();
+ Value newResult = newOp.getResult();
+ Value replacement = (newResult.getType() != oldResult.getType())
+ ? rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult)
+ : newResult;
+
+ rewriter.replaceOp(op, {replacement});
+
+ return success();
+ }
+};
+
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
///
@@ -4890,7 +4969,7 @@ struct FoldTensorCastProducerOp
PatternRewriter &rewriter) const override {
-    // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    // Reject tensor::PackOp and tensor::UnPackOp - there are dedicated
+    // patterns for those instead.
+    if (!foldTensorCastPrecondition(op) ||
+        isa<tensor::PackOp, tensor::UnPackOp>(*op))
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +5002,7 @@ struct FoldTensorCastProducerOp
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
+ results.add<FoldTensorCastUnPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e18f..88e3691e2d6297 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}
+
// -----
// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(
// -----
+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK: return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+ %src: tensor<1x1x8x1xi32>,
+ %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+ %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+ %c8 = arith.constant 8 : index
+ %unpack = tensor.unpack %cast
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%c8, 1]
+ into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+ return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
// CHECK: tensor.pack {{.*}} {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab247..1de3e281bc462b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
// -----
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
// expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
``````````
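As a quick usage sketch (the function name below is illustrative, not part
of the patch), running `mlir-opt --canonicalize` over IR mirroring the new
test should now fold the `tensor.cast` into the `tensor.unpack` with static
inner tiles instead of producing an op that fails verification:
```mlir
// Input for: mlir-opt --canonicalize
func.func @fold_example(%src: tensor<1x1x8x1xi32>,
                        %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
  %c8 = arith.constant 8 : index
  // The cast makes the tiled dim dynamic; the new pattern folds it away.
  %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
  %unpack = tensor.unpack %cast
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
  return %unpack : tensor<7x?xi32>
}
```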
https://github.com/llvm/llvm-project/pull/121393