[Mlir-commits] [mlir] [mlir][scf] fuse `tensor.pack` as consumer (PR #103715)
llvmlistbot at llvm.org
Thu Aug 15 19:44:17 PDT 2024
================
@@ -246,6 +246,97 @@ struct PackOpTiling
return failure();
return tilingResult.value();
}
+
+ /// Method to return the position of iteration domain tile computed by the
+ /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
+ /// `resultSizes` only cover outer dimensions.
+ LogicalResult getIterationDomainTileFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ auto packOp = cast<PackOp>(op);
+ Location loc = packOp.getLoc();
+
+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ packOp.getDimAndTileMapping();
+ for (auto dim : packOp.getOuterDimsPerm()) {
+ if (dimAndTileMapping.count(dim)) {
+ FailureOr<int64_t> cstSize =
+ ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType::UB, sizes[dim],
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
+ std::optional<int64_t> cstInnerSize =
+ getConstantIntValue(dimAndTileMapping[dim]);
+ // Currently only expect perfect tiling cases.
+ if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+ return failure();
+ }
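For readers following the hunk above, here is a small C++ sketch of what the perfect-tiling constraint buys. It uses plain integers and no MLIR APIs, and it rests on several assumptions. Offsets, sizes and inner tile sizes are exact constants, whereas the patch checks a constant upper bound of the size. `outer_dims_perm` is the identity. Operand tiles start at multiples of the tile size, so offsets divide exactly. Finally, it assumes the truncated remainder of the function divides the operand tile by the inner tile size to get the outer-dimension tile. `OuterTile` and `outerTileFromOperandTile` are hypothetical names.

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <vector>

// Hypothetical, MLIR-free model of the perfect-tiling check: constants only,
// identity outer_dims_perm.
struct OuterTile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// innerTiles maps a source dimension to its constant inner tile size
// (the analogue of getDimAndTileMapping()).
std::optional<OuterTile> outerTileFromOperandTile(
    const std::vector<int64_t> &offsets, const std::vector<int64_t> &sizes,
    const std::map<int64_t, int64_t> &innerTiles) {
  OuterTile result;
  for (int64_t dim = 0; dim < static_cast<int64_t>(sizes.size()); ++dim) {
    auto it = innerTiles.find(dim);
    if (it == innerTiles.end()) {
      // Untiled dimension: the outer tile is the operand tile itself.
      result.offsets.push_back(offsets[dim]);
      result.sizes.push_back(sizes[dim]);
      continue;
    }
    int64_t inner = it->second;
    // Perfect tiling only: the operand tile must cover whole inner tiles.
    if (sizes[dim] % inner != 0)
      return std::nullopt;
    // Assumes offsets are multiples of `inner`, so this division is exact.
    result.offsets.push_back(offsets[dim] / inner);
    result.sizes.push_back(sizes[dim] / inner);
  }
  return result;
}
```

For example, a 16-element operand tile at offset 16 with an inner tile of 8 maps to outer offset 2 and outer size 2. A 15-element tile is rejected, mirroring the `failure()` above.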
----------------
Yun-Fly wrote:
IIUC, non-perfect tiling cases exist even without padding semantics, just like the example you left in the comment:
```
/// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
/// coordinates of second tile (i.e., result[15..31]) are
/// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
/// row are incomplete tiles. To represent the unpack op, we have to complete
/// the rows. I.e., the input coordinates would start with (1, 0); end with
/// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
/// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
/// can get the actual result.
```
Such cases involve incomplete tiles and are not trivial to handle, so I just added this constraint here to bypass the complex case. Perhaps you can help refine this feature after this patch :)
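To make the incomplete-tile case concrete, here is a small C++ sketch. It is hypothetical, uses plain integers and no MLIR APIs, and assumes `size > 0` and a non-negative `offset`. It computes which rows of the `Nn` layout a contiguous flat slice touches. With N=32, n=8 and tiling_size=15, the second tile covers flat indices 15 through 29, i.e. coordinates (1, 7) to (3, 5). It therefore touches rows 1 to 3, both its first and last rows are partial, and completing them takes 3 * n = 24 elements. `RowSpan` and `rowSpanOfSlice` are illustrative names.

```cpp
#include <cstdint>

// Hypothetical helper for a 1-D layout of N elements split into rows of n
// elements (the Nn <-> N case from the quoted comment). It describes which
// rows a contiguous slice [offset, offset + size) of the flat tensor touches.
struct RowSpan {
  int64_t firstRow;      // row containing `offset`
  int64_t lastRow;       // row containing `offset + size - 1`
  bool firstRowPartial;  // slice starts in the middle of a row
  bool lastRowPartial;   // slice ends in the middle of a row
  int64_t paddedElems;   // elements after completing the rows: rows * n
};

RowSpan rowSpanOfSlice(int64_t offset, int64_t size, int64_t n) {
  int64_t last = offset + size - 1;  // inclusive end index of the slice
  RowSpan s;
  s.firstRow = offset / n;
  s.lastRow = last / n;
  s.firstRowPartial = offset % n != 0;
  s.lastRowPartial = last % n != n - 1;
  s.paddedElems = (s.lastRow - s.firstRow + 1) * n;
  return s;
}
```

With tiling_size=16 instead, every tile starts and ends on a row boundary, so neither row is partial and no completion is needed. That is the perfect-tiling case the constraint keeps.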
https://github.com/llvm/llvm-project/pull/103715