[Mlir-commits] [mlir] [mlir][scf] fuse `tensor.pack` as consumer (PR #103715)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 14 01:16:01 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: None (Yun-Fly)
<details>
<summary>Changes</summary>
Add the missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` implementations for `tensor.pack`, which enables fusing it as a consumer.
---
Full diff: https://github.com/llvm/llvm-project/pull/103715.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+91)
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+59)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 361340a4e62f2d..51c232ae77fe6c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -246,6 +246,97 @@ struct PackOpTiling
return failure();
return tilingResult.value();
}
+
+ /// Maps an operand tile (`offsets`/`sizes`; this patch passes the source tile) to
+ /// the matching `tensor.pack` iteration-domain tile in `resultOffsets`/`resultSizes`.
+ /// Fails unless each tiled dim has a constant inner tile dividing a constant size bound.
+ LogicalResult getIterationDomainTileFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ auto packOp = cast<PackOp>(op);
+ Location loc = packOp.getLoc();
+ // NOTE(review): `operandNumber` is never read in this body -- confirm only operand 0 is expected.
+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ packOp.getDimAndTileMapping();
+ for (auto dim : packOp.getOuterDimsPerm()) { // NOTE(review): visits nothing if this list is empty -- confirm
+ if (dimAndTileMapping.count(dim)) { // `dim` has an entry in the dim -> tile mapping
+ FailureOr<int64_t> cstSize =
+ ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType::UB, sizes[dim],
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
+ std::optional<int64_t> cstInnerSize =
+ getConstantIntValue(dimAndTileMapping[dim]);
+ // Only perfect tiling is handled: both values constant, bound divisible by the tile.
+ if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+ return failure();
+ }
+ // Emit floor(offset / tile) and ceil(size / tile) through the affine builder.
+ using AV = affine::AffineValueExpr;
+ affine::AffineBuilder ab(b, loc);
+ AffineExpr dim0, sym;
+ bindDims(b.getContext(), dim0);
+ bindSymbols(b.getContext(), sym);
+ auto avOffset = AV(dim0).bind(offsets[dim]);
+ auto avSize = AV(dim0).bind(sizes[dim]);
+ auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+ outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
+ outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
+ } else { // no tile entry for `dim`: forward its offset/size unchanged
+ outerDimOffsets.push_back(offsets[dim]);
+ outerDimSizes.push_back(sizes[dim]);
+ }
+ }
+ // One entry per visited dim; in the current `tensor.pack` context only outer dims are covered.
+ resultOffsets = outerDimOffsets;
+ resultSizes = outerDimSizes;
+ return success();
+ }
+
+ /// Builds the tiled `tensor.pack` consuming the source tile `offsets`/`sizes`; failure is propagated.
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ auto packOp = cast<PackOp>(op);
+ Location loc = packOp.getLoc();
+ // NOTE(review): `operandNumber` is ignored; 0 is hard-coded in the calls below -- confirm intent.
+ int64_t inputRank = packOp.getSourceRank();
+ auto oneAttr = b.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+ // Tiled source: unit-stride slice of the pack source at the requested tile.
+ SmallVector<Value> tiledOperands;
+ tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
+ offsets, sizes, strides));
+ // Translate the source tile into an iteration-domain tile; bail out if that fails.
+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+ if (failed(getIterationDomainTileFromOperandTile(
+ op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+ outerDimSizes)))
+ return failure();
+ // Derive the destination tile position from the iteration-domain tile.
+ SmallVector<OpFoldResult> outputOffsets, outputSizes;
+ if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
+ outputOffsets, outputSizes)))
+ return failure();
+ // Pad the unit strides up to the destination rank, then slice the destination.
+ strides.append(packOp.getDestRank() - inputRank, oneAttr);
+ auto extractSlice = b.create<ExtractSliceOp>(
+ loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+ tiledOperands.push_back(extractSlice);
+ // Forward the padding value (when present) and the inner tile values untouched.
+ if (auto val = packOp.getPaddingValue())
+ tiledOperands.push_back(val);
+ for (auto tile : packOp.getInnerTiles())
+ tiledOperands.push_back(tile);
+ // Re-create the pack on the tiled operands, reusing the original op's attributes.
+ Operation *tiledPackOp = b.create<PackOp>(
+ loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+ // Report the new op together with its results.
+ return TilingResult{{tiledPackOp},
+ SmallVector<Value>(tiledPackOp->getResults())};
+ }
};
struct UnpackTileDimInfo {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcda..741dfbfb1cd5c2 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,62 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)> // 2-D identity map, used for every operand of the linalg.generic below
+module { // Test input: an scf.forall producer whose result feeds a tensor.pack
+ func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+ %c4 = arith.constant 4 : index // NOTE(review): %c4, %c64 and %c0 are not referenced in this function
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel { // this insert is the op matched by the transform sequence in this test
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %output = tensor.empty() : tensor<4x32x16xf32> // destination of the pack below
+ %pack = tensor.pack %1 outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+ return %pack : tensor<4x32x16xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} { // Transform script driving the test
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op // handle to the matched tensor.parallel_insert_slice op(s)
+ %a, %b = transform.test.fuse_consumer %slice_op
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // run the test consumer-fusion op on that handle
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
+// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
``````````
</details>
https://github.com/llvm/llvm-project/pull/103715
More information about the Mlir-commits
mailing list