[Mlir-commits] [mlir] Full slices when tiling full loop trip count (PR #127197)
llvmlistbot at llvm.org
Fri Feb 14 03:09:40 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: None (josel-amd)
<details>
<summary>Changes</summary>
When tiling a chain of linalg ops, only the tile sizes of the first op can be set to 0 to mean "untiled"; its producers instead receive a tile size equal to the loop trip count. In those cases we must return the full slice, because the code that computes slice sizes in the general case does not handle non-monotonic affine expressions. Otherwise we would generate invalid code for non-monotonic expressions even though all involved dimensions are effectively untiled.
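To make the failure mode concrete: with an indexing map such as `(d0) -> (d0 mod 4)` (the map used in the new `tile-tensors.mlir` test further down), the accessed indices wrap around, so a rectangular slice derived from a tile's offset and size is not valid even when the tile spans the whole iteration space. The sketch below is illustrative only (plain C++ arithmetic, not MLIR code) and just enumerates the wrap-around:

```cpp
#include <cstdio>

// Illustration only: enumerate the indices accessed by the expression
// "d0 mod 4" over the full iteration space of the tensor<7xf32> test case.
// The accessed indices are 0,1,2,3,0,1,2 -- not a monotonically increasing
// range -- so a sub-slice computed from (offset, size) alone would be wrong;
// the only safe slice for this operand is the full one.
int main() {
  const int tripCount = 7;
  for (int d0 = 0; d0 < tripCount; ++d0)
    std::printf("d0 = %d accesses index %d\n", d0, d0 % 4);
  return 0;
}
```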
---
Full diff: https://github.com/llvm/llvm-project/pull/127197.diff
4 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+6-3)
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+25-10)
- (modified) mlir/test/Dialect/Linalg/tile-tensors.mlir (+32)
- (modified) mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir (+2-1)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index b7764da26a7f4..a838b99c9dbb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -115,13 +115,16 @@ struct LinalgOpTilingInterface
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
- // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
- // specified could lead to out of bounds accesses.
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
+ SmallVector<OpFoldResult> allShapeSizes =
+ linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc());
+ SmallVector<OpFoldResult> sizeBounds =
+ mlir::affine::makeComposedFoldedMultiResultAffineApply(
+ b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes);
SmallVector<Value> valuesToTile = linalgOp->getOperands();
SmallVector<Value> tiledOperands = makeTiledShapes(
- b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+ b, loc, linalgOp, valuesToTile, offsets, sizes, sizeBounds, true);
SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
llvm::make_filter_range(
tiledOperands,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d148067fe6343..3f0382f4bc8b6 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -56,10 +56,23 @@ namespace {
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
- TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
+ TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds)
+ : tileSizes(tileSizes), sizeBounds(sizeBounds) {}
void visitDimExpr(AffineDimExpr expr) {
- isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
+ unsigned pos = expr.getPosition();
+
+ // This dimension is tiled if the tile size is larger than zero and not
+ // equal to its domain size (if statically known).
+ std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
+ if (tileSize && !sizeBounds.empty()) {
+ std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
+ if (sizeBound && *sizeBound == *tileSize) {
+ return;
+ }
+ }
+
+ isTiled |= !isZeroIndex(tileSizes[pos]);
}
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
visit(expr.getLHS());
@@ -70,24 +83,27 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
}
bool isTiled = false;
ArrayRef<OpFoldResult> tileSizes;
+ ArrayRef<OpFoldResult> sizeBounds;
};
} // namespace
-static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
+static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> sizeBounds) {
if (!expr)
return false;
- TileCheck t(tileSizes);
+ TileCheck t(tileSizes, sizeBounds);
t.visit(expr);
return t.isTiled;
}
// Checks whether the `map` varies with respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
+static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> sizeBounds) {
if (!map)
return false;
for (unsigned r = 0; r < map.getNumResults(); ++r)
- if (isTiled(map.getResult(r), tileSizes))
+ if (isTiled(map.getResult(r), tileSizes, sizeBounds))
return true;
return false;
}
@@ -581,7 +597,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
sliceParams.strides.reserve(rank);
for (unsigned r = 0; r < rank; ++r) {
LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
- if (!isTiled(map.getSubMap({r}), tileSizes)) {
+ if (!isTiled(map.getSubMap({r}), tileSizes, ubs)) {
sliceParams.offsets.push_back(builder.getIndexAttr(0));
OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
sliceParams.sizes.push_back(dim);
@@ -781,10 +797,9 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
// transformations such as padding and bufferization since the
// extract/insert slice pairs make the accessed iteration argument
// subdomains explicit.
-
Type operandType = opOperand.get().getType();
- if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
- linalgOp.isDpsInit(&opOperand))) {
+ if (!isTiled(map, tileSizes, {}) && !(isa<RankedTensorType>(operandType) &&
+ linalgOp.isDpsInit(&opOperand))) {
allSliceParams.push_back(std::nullopt);
LLVM_DEBUG(llvm::dbgs()
<< ": not tiled: use shape: " << operandType << "\n");
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 557233d8aa3ec..b179a5966cc08 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -214,3 +214,35 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func @non_monotonic_affine_expr
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<7xf32>
+func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %arg0, %c0 : tensor<7xf32>
+ %empty = tensor.empty() : tensor<7xf32>
+
+ // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32>
+ // CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) {
+ // CHECK: tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32>
+ %generic = linalg.generic
+ {indexing_maps = [affine_map<(d0) -> (d0 mod 4)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%arg0: tensor<7xf32>)
+ outs(%empty : tensor<7xf32>) {
+ ^bb0(%in : f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<7xf32>
+ return %generic : tensor<7xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop = transform.structured.tile_using_for %0 tile_sizes [7] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c..525561b8ea2f5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -555,11 +555,12 @@ module {
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
+ // CHECK: %[[T3:.*]] = linalg.generic {{.*}}
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
%8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
- // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ // CHECK: tensor.parallel_insert_slice %[[T3]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
}
}
``````````
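For readers skimming the diff, here is a minimal, self-contained C++ sketch of the rule the updated `TileCheck::visitDimExpr` applies; the helper name `isDimEffectivelyTiled` is made up for illustration and is not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper modelling the check added in Utils.cpp: a dimension is
// treated as untiled when its tile size is statically 0, or when it equals the
// statically known loop bound (i.e. the tile spans the full trip count).
static bool isDimEffectivelyTiled(std::optional<int64_t> tileSize,
                                  std::optional<int64_t> sizeBound) {
  if (tileSize && *tileSize == 0)
    return false; // Tile size 0 means "untiled" by convention.
  if (tileSize && sizeBound && *tileSize == *sizeBound)
    return false; // Tiling by the full trip count is equivalent to no tiling.
  return true;    // Otherwise (including dynamic sizes) the dim is tiled.
}

int main() {
  assert(!isDimEffectivelyTiled(7, 7));           // tensor<7xf32>, tile size 7.
  assert(isDimEffectivelyTiled(4, 7));            // Genuinely tiled.
  assert(isDimEffectivelyTiled(4, std::nullopt)); // Unknown bound: keep tiling.
  return 0;
}
```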
</details>
https://github.com/llvm/llvm-project/pull/127197