[Mlir-commits] [mlir] [mlir][tensor] Introduce `FoldTensorCastUnPackOp` (PR #121393)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jan 2 09:20:40 PST 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/121393
From 1bc2d8eaced67b9e2e4a6893e18db49a76a4f61b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 31 Dec 2024 14:02:45 +0000
Subject: [PATCH 1/2] [mlir][tensor] Introduce `FoldTensorCastUnPackOp`
This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change
mirrors a similar update made for `tensor::PackOp` in #114559. Below is
the updated rationale for `tensor::UnPackOp`.
Currently, `FoldTensorCastProducerOp` incorrectly folds the following:
```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%unpack = tensor.unpack %cast
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```
as:
```mlir
%unpack = tensor.unpack %dest
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```
This leads to an Op verification failure: the folder removes the cast (making
the unpack source type static) without updating the inner tile sizes, which
remain dynamic (`%c8`). This patch resolves the issue by introducing a
dedicated pattern that updates the tile sizes alongside the types.
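For comparison, the expected fold with the new pattern - a sketch inferred from
the `fold_cast_unpack_dynamic_tile_size` test added in this patch - folds the
cast away and replaces the dynamic tile size `%c8` with the static `8`:
```mlir
%unpack = tensor.unpack %dest
inner_dims_pos = [0, 1]
inner_tiles = [8, 1]
into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```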
Additional Changes:
* invalid.mlir: Fixes a typo.
* TensorOps.cpp: Removes unnecessary `(void)tileSize` and adds extra
comments following this discussion:
https://github.com/llvm/llvm-project/pull/115772.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 89 +++++++++++++++++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 21 +++++
mlir/test/Dialect/Tensor/invalid.mlir | 2 +-
3 files changed, 107 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..aeb11186c124da 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
// Already a constant
newMixedTileSizes.push_back(std::get<1>(it));
} else {
- int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
- assert(tileSize == shape && "tile size and dim size don't match!");
- (void)tileSize;
+ assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+ "tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
}
}
// Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
}
};
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+/// %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(UnPackOp op,
+ PatternRewriter &rewriter) const override {
+ if (!foldTensorCastPrecondition(op))
+ return failure();
+
+ SmallVector<Type> newResultTypes(op->getResultTypes());
+ SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ Value sourceTensor = newOperands[0];
+
+ // Get the updated mixed-tile-sizes attribute.
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
+ .getShape()
+ .take_back(op.getMixedTiles().size()),
+ op.getMixedTiles())) {
+ int64_t shape = std::get<0>(it);
+ // If the current source shape is dynamic, just preserve this mixed
+ // size.
+ if (shape == ShapedType::kDynamic) {
+ newMixedTileSizes.push_back(std::get<1>(it));
+ continue;
+ }
+
+ // If the current source is static, update the dynamic mixed-size
+ // (provided the original value is dynamic).
+ if (Attribute attr =
+ llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+ // Already a constant
+ newMixedTileSizes.push_back(std::get<1>(it));
+ } else {
+ assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+ "tile size and dim size don't match!");
+ newMixedTileSizes.push_back(
+ (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ }
+ }
+
+ // Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
+ UnPackOp newOp = rewriter.create<UnPackOp>(
+ op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getOuterDimsPerm());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+ // Replace op.
+ Value oldResult = op.getResult();
+ Value newResult = newOp.getResult();
+ Value replacement = (newResult.getType() != oldResult.getType())
+ ? rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult)
+ : newResult;
+
+ rewriter.replaceOp(op, {replacement});
+
+ return success();
+ }
+};
+
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
///
@@ -4890,7 +4969,8 @@ struct FoldTensorCastProducerOp
PatternRewriter &rewriter) const override {
// Reject tensor::PackOp - there's dedicated pattern for that instead.
- if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+ if (!foldTensorCastPrecondition(op) ||
+ isa<tensor::PackOp, tensor::UnPackOp>(*op))
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +5003,7 @@ struct FoldTensorCastProducerOp
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
+ results.add<FoldTensorCastUnPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e18f..88e3691e2d6297 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}
+
// -----
// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(
// -----
+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK: return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+ %src: tensor<1x1x8x1xi32>,
+ %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+ %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+ %c8 = arith.constant 8 : index
+ %unpack = tensor.unpack %cast
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%c8, 1]
+ into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+ return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
// CHECK: tensor.pack {{.*}} {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab247..1de3e281bc462b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
// -----
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
// expected-error at +1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
From 10d26d94d6095cbd3f7a61ae517b5019d5ebbaaa Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 2 Jan 2025 17:19:34 +0000
Subject: [PATCH 2/2] fixup! [mlir][tensor] Introduce `FoldTensorCastUnPackOp`
Address PR comments
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 92 ++++++++++------------
mlir/test/Dialect/Tensor/canonicalize.mlir | 8 +-
2 files changed, 47 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index aeb11186c124da..24a1d553153198 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
return newOperands;
}
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+// * a dim from newPackedTy is static, and
+// * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+ SmallVector<OpFoldResult> mixedTiles) {
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+ .getShape()
+ .take_back(mixedTiles.size()),
+ mixedTiles)) {
+ int64_t shape = std::get<0>(it);
+ if (shape == ShapedType::kDynamic) {
+ newMixedTileSizes.push_back(std::get<1>(it));
+ continue;
+ }
+
+ // If the current result dim is static, update the dynamic mixed-size
+ // (provided the original value is dynamic).
+ OpFoldResult tile = std::get<1>(it);
+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+ // Already a constant
+ newMixedTileSizes.push_back(tile);
+ } else {
+ assert(getConstantIntValue(tile).value() == shape &&
+ "tile size and dim size don't match!");
+ newMixedTileSizes.push_back(
+ (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ }
+ }
+
+ return newMixedTileSizes;
+}
+
/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
@@ -4821,28 +4859,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
// Get the updated mixed-tile-sizes attribute.
- SmallVector<OpFoldResult> newMixedTileSizes;
- for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
- .getShape()
- .take_back(op.getMixedTiles().size()),
- op.getMixedTiles())) {
- int64_t shape = std::get<0>(it);
- if (shape == ShapedType::kDynamic) {
- newMixedTileSizes.push_back(std::get<1>(it));
- continue;
- }
-
- if (Attribute attr =
- llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
- // Already a constant
- newMixedTileSizes.push_back(std::get<1>(it));
- } else {
- assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
- "tile size and dim size don't match!");
- newMixedTileSizes.push_back(
- (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
- }
- }
+ SmallVector<OpFoldResult> newMixedTileSizes =
+ getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
// Clone op.
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -4873,7 +4891,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
-/// %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
/// ```
///
/// folds into:
@@ -4894,32 +4912,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
Value sourceTensor = newOperands[0];
// Get the updated mixed-tile-sizes attribute.
- SmallVector<OpFoldResult> newMixedTileSizes;
- for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
- .getShape()
- .take_back(op.getMixedTiles().size()),
- op.getMixedTiles())) {
- int64_t shape = std::get<0>(it);
- // If the current source shape is dynamic, just preserve this mixed
- // size.
- if (shape == ShapedType::kDynamic) {
- newMixedTileSizes.push_back(std::get<1>(it));
- continue;
- }
-
- // If the current source is static, update the dynamic mixed-size
- // (provided the original value is dynamic).
- if (Attribute attr =
- llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
- // Already a constant
- newMixedTileSizes.push_back(std::get<1>(it));
- } else {
- assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
- "tile size and dim size don't match!");
- newMixedTileSizes.push_back(
- (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
- }
- }
+ SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+ rewriter, sourceTensor.getType(), op.getMixedTiles());
// Clone op.
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 88e3691e2d6297..01d14871072cdf 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2795,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
-// CHECK-SAME: some_attr
+// CHECK-SAME: test_attr
// CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
@@ -2808,7 +2808,7 @@ func.func @fold_cast_pack_dynamic_tile_size(
%pack = tensor.pack %src padding_value(%pad : i32)
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
- into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+ into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
return %res : tensor<1x1x8x1xi32>
}
@@ -2818,7 +2818,7 @@ func.func @fold_cast_pack_dynamic_tile_size(
// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
-// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
// CHECK: return %[[RES]] : tensor<7x?xi32>
func.func @fold_cast_unpack_dynamic_tile_size(
%src: tensor<1x1x8x1xi32>,
@@ -2829,7 +2829,7 @@ func.func @fold_cast_unpack_dynamic_tile_size(
%unpack = tensor.unpack %cast
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
- into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+ into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
return %unpack : tensor<7x?xi32>
}