[Mlir-commits] [mlir] 9f6a1dd - [mlir][tensor] Introduce `FoldTensorCastUnPackOp` (#121393)
llvmlistbot at llvm.org
Fri Jan 3 01:09:27 PST 2025
Author: Andrzej Warzyński
Date: 2025-01-03T09:09:23Z
New Revision: 9f6a1ddb43133328c90edfa29ccd4c714b289cb6
URL: https://github.com/llvm/llvm-project/commit/9f6a1ddb43133328c90edfa29ccd4c714b289cb6
DIFF: https://github.com/llvm/llvm-project/commit/9f6a1ddb43133328c90edfa29ccd4c714b289cb6.diff
LOG: [mlir][tensor] Introduce `FoldTensorCastUnPackOp` (#121393)
This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This mirrors a
similar update made for `tensor::PackOp` in #114559. Below is the updated
rationale tailored to `tensor::UnPackOp`.
ISSUE DESCRIPTION
Currently, `FoldTensorCastProducerOp` incorrectly folds the following:
```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
// Note: `%c8` and `?`.
%unpack = tensor.unpack %cast
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```
as:
```mlir
// Note: `%c8` and `8`.
%unpack = tensor.unpack %cast
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```
This triggers an Op verification failure because the folder does not update
the inner tile sizes in the unpack Op: after folding, the tiled dim of the
packed type is the static `8`, while the corresponding entry in `inner_tiles`
is still the dynamic `%c8`. This patch addresses the issue by updating the
inner tile sizes alongside the operand types.
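With this patch, the `tensor.cast` is folded away and the tile sizes are
updated together with the type, so the example above instead folds to the
following (this is precisely what the new `@fold_cast_unpack_dynamic_tile_size`
test below checks):
```mlir
// Note: `8` and `8` - the tile size now matches the static dim.
%unpack = tensor.unpack %dest
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 1]
    into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```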
ADDITIONAL CHANGES
* invalid.mlir: Fixed a typo.
* TensorOps.cpp:
  * Removed unnecessary `(void)tileSize`.
  * Added comments following the discussion in PR #115772.
  * Made minor updates to `FoldTensorCastPackOp` for consistency with the
    newly introduced `FoldTensorCastUnPackOp`.
* Tensor/canonicalize.mlir: Ensured consistent usage of `test_attr` (e.g.,
  replaced mixed use of `test_attr` and `some_attr`).
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..24a1d553153198 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
return newOperands;
}
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+// * a dim from newPackedTy is static, and
+// * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+ SmallVector<OpFoldResult> mixedTiles) {
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+ .getShape()
+ .take_back(mixedTiles.size()),
+ mixedTiles)) {
+ int64_t shape = std::get<0>(it);
+ if (shape == ShapedType::kDynamic) {
+ newMixedTileSizes.push_back(std::get<1>(it));
+ continue;
+ }
+
+ // If the current result dim is static, update the dynamic mixed-size
+ // (provided the original value is dynamic).
+ OpFoldResult tile = std::get<1>(it);
+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+ // Already a constant
+ newMixedTileSizes.push_back(tile);
+ } else {
+ assert(getConstantIntValue(tile).value() == shape &&
+ "tile size and dim size don't match!");
+ newMixedTileSizes.push_back(
+ (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ }
+ }
+
+ return newMixedTileSizes;
+}
+
/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
@@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
// Get the updated mixed-tile-sizes attribute.
- SmallVector<OpFoldResult> newMixedTileSizes;
- for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
- .getShape()
- .take_back(op.getMixedTiles().size()),
- op.getMixedTiles())) {
- int64_t shape = std::get<0>(it);
- if (shape == ShapedType::kDynamic) {
- newMixedTileSizes.push_back(std::get<1>(it));
- continue;
- }
-
- if (Attribute attr =
- llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
- // Already a constant
- newMixedTileSizes.push_back(std::get<1>(it));
- } else {
- int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
- assert(tileSize == shape && "tile size and dim size don't match!");
- (void)tileSize;
- newMixedTileSizes.push_back(
- (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
- }
- }
+ SmallVector<OpFoldResult> newMixedTileSizes =
+ getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
// Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
}
};
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(UnPackOp op,
+ PatternRewriter &rewriter) const override {
+ if (!foldTensorCastPrecondition(op))
+ return failure();
+
+ SmallVector<Type> newResultTypes(op->getResultTypes());
+ SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ Value sourceTensor = newOperands[0];
+
+ // Get the updated mixed-tile-sizes attribute.
+ SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+ rewriter, sourceTensor.getType(), op.getMixedTiles());
+
+ // Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
+ UnPackOp newOp = rewriter.create<UnPackOp>(
+ op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getOuterDimsPerm());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+ // Replace op.
+ Value oldResult = op.getResult();
+ Value newResult = newOp.getResult();
+ Value replacement = (newResult.getType() != oldResult.getType())
+ ? rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult)
+ : newResult;
+
+ rewriter.replaceOp(op, {replacement});
+
+ return success();
+ }
+};
+
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
///
@@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp
PatternRewriter &rewriter) const override {
// Reject tensor::PackOp - there's dedicated pattern for that instead.
- if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+ if (!foldTensorCastPrecondition(op) ||
+ isa<tensor::PackOp, tensor::UnPackOp>(*op))
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
+ results.add<FoldTensorCastUnPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e18f..01d14871072cdf 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}
+
// -----
// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2794,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
-// CHECK-SAME: some_attr
+// CHECK-SAME: test_attr
// CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
@@ -2807,13 +2808,33 @@ func.func @fold_cast_pack_dynamic_tile_size(
%pack = tensor.pack %src padding_value(%pad : i32)
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
- into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+ into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
return %res : tensor<1x1x8x1xi32>
}
// -----
+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK: return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+ %src: tensor<1x1x8x1xi32>,
+ %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+ %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+ %c8 = arith.constant 8 : index
+ %unpack = tensor.unpack %cast
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%c8, 1]
+ into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+ return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
// CHECK: tensor.pack {{.*}} {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab247..1de3e281bc462b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
// -----
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
// expected-error at +1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
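For readers who want to exercise the new fold programmatically rather than via
`mlir-opt -canonicalize`, below is a minimal standalone driver. This is a
sketch, not part of this commit; it assumes a standard upstream MLIR build and
uses only public MLIR APIs. The canonicalizer collects each dialect's
canonicalization patterns, which now include `FoldTensorCastUnPackOp` via
`TensorDialect::getCanonicalizationPatterns`:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

int main() {
  mlir::MLIRContext ctx;
  // The input uses func, arith and tensor ops, so load those dialects.
  ctx.loadDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
                  mlir::tensor::TensorDialect>();

  // The example from the new @fold_cast_unpack_dynamic_tile_size test.
  const char *ir = R"mlir(
    func.func @fold(%src: tensor<1x1x8x1xi32>,
                    %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
      %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
      %c8 = arith.constant 8 : index
      %unpack = tensor.unpack %cast
          inner_dims_pos = [0, 1]
          inner_tiles = [%c8, 1]
          into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
      return %unpack : tensor<7x?xi32>
    }
  )mlir";

  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, &ctx);
  if (!module)
    return 1;

  // Run the canonicalizer, which applies the tensor dialect's
  // canonicalization patterns, including FoldTensorCastUnPackOp.
  mlir::PassManager pm(&ctx);
  pm.addPass(mlir::createCanonicalizerPass());
  if (mlir::failed(pm.run(*module)))
    return 1;

  // Expected output: tensor.unpack %src ... inner_tiles = [8, 1] ...
  // with the tensor.cast folded away.
  module->dump();
  return 0;
}
```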