[Mlir-commits] [mlir] 8b68cec - [mlir][tensor] Add producer fusion for tensor.pack op.
Author: Hanhan Wang
Date: 2023-08-16T11:02:59-07:00
New Revision: 8b68cec9c0a42bbba04e7ed2e01136c16236cd50
URL: https://github.com/llvm/llvm-project/commit/8b68cec9c0a42bbba04e7ed2e01136c16236cd50
DIFF: https://github.com/llvm/llvm-project/commit/8b68cec9c0a42bbba04e7ed2e01136c16236cd50.diff
LOG: [mlir][tensor] Add producer fusion for tensor.pack op.
We can fuse the pack op as a producer only if its inner tiles are not
tiled or are fully used by the consumer tile. Otherwise, fusion would
have to generate a sequence of non-trivial ops for the partial inner
tiles.
Differential Revision: https://reviews.llvm.org/D157932
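
For illustration only (not part of this commit): below is a hand-written MLIR
sketch of roughly what the fused producer tile looks like for the
@pack_elemwise test in this diff, with outer tile sizes [3, 5] and 8x8 inner
tiles. The function name, the %i/%j induction variables, and the affine.apply
offset computation are invented for the example, and boundary iterations
(3 and 5 do not divide 16 and 48) are ignored. The point is that when the
inner tiles are kept whole, the packed slice maps back to a plain 24x40 slice
of the 128x384 source followed by a smaller tensor.pack; a partial inner tile
would not map back this cleanly, which is why generateResultTileValue bails
out in that case.

// Hypothetical sketch: the producer tile for a full [3, 5, 8, 8] slice of the
// packed result at outer offsets [%i, %j]. Boundary iterations are ignored.
func.func @pack_producer_tile_sketch(%src: tensor<128x384xf32>,
                                     %dest: tensor<16x48x8x8xf32>,
                                     %i: index, %j: index) -> tensor<3x5x8x8xf32> {
  // Outer offsets on the packed result scale by the inner tile sizes (8, 8)
  // to give offsets into the unpacked source.
  %off0 = affine.apply affine_map<(d0) -> (d0 * 8)>(%i)
  %off1 = affine.apply affine_map<(d0) -> (d0 * 8)>(%j)
  // 3 and 5 full inner tiles of 8 elements each: a 24x40 source slice.
  %src_slice = tensor.extract_slice %src[%off0, %off1] [24, 40] [1, 1]
      : tensor<128x384xf32> to tensor<24x40xf32>
  %dest_slice = tensor.extract_slice %dest[%i, %j, 0, 0] [3, 5, 8, 8] [1, 1, 1, 1]
      : tensor<16x48x8x8xf32> to tensor<3x5x8x8xf32>
  // The tiled pack keeps the same inner_dims_pos/inner_tiles; only the outer
  // extent shrinks, which is why full inner tiles make the fusion trivial.
  %tile = tensor.pack %src_slice inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %dest_slice : tensor<24x40xf32> -> tensor<3x5x8x8xf32>
  return %tile : tensor<3x5x8x8xf32>
}

In the @nofuse_pack_elemwise test, tile size 2 on the third dimension would
split the 8-wide inner tile, so no such clean mapping exists and the pack
producer is left outside the loop nest.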
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Dialect/Linalg/transform-op-fuse.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 56cfdfcf0b8b92..67080d8e301c13 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -220,6 +220,32 @@ struct PackOpTiling
return success();
}
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    auto packOp = cast<PackOp>(op);
+    int64_t numTiles = packOp.getInnerDimsPos().size();
+
+    // tensor.pack op is fusible (as a producer) only if full inner tiles are
+    // iterated or inner dims are not tiled. Otherwise, it will generate a
+    // sequence of non-trivial ops (for partial tiles).
+    for (auto offset : offsets.take_back(numTiles))
+      if (!isConstantIntValue(offset, 0))
+        return failure();
+
+    for (auto iter :
+         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
+      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
+        return failure();
+
+    FailureOr<TilingResult> tilingResult = getTiledImplementation(
+        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
+    if (failed(tilingResult))
+      return failure();
+    return tilingResult.value();
+  }
};
struct UnpackTileDimInfo {
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 20beefb351f81d..b9f12fcc3057a6 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -118,3 +118,51 @@ transform.sequence failures(propagate) {
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_elemwise
+// CHECK: %[[RES:.*]] = scf.for
+// CHECK: scf.for
+// CHECK: tensor.pack
+// CHECK: linalg.elemwise_unary
+// CHECK: return %[[RES]]
+func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
+  %0 = tensor.empty() : tensor<16x48x8x8xf32>
+  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
+    : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
+  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
+                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
+  return %2 : tensor<16x48x8x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]}
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @nofuse_pack_elemwise
+// CHECK: tensor.pack
+// CHECK: %[[RES:.*]] = scf.for
+// CHECK: scf.for
+// CHECK: linalg.elemwise_unary
+// CHECK: return %[[RES]]
+func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
+  %0 = tensor.empty() : tensor<16x48x8x8xf32>
+  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
+    : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
+  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
+                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
+  return %2 : tensor<16x48x8x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]}
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}