[Mlir-commits] [mlir] Full slices when tiling full loop trip count (PR #127197)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 14 03:09:40 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: None (josel-amd)

<details>
<summary>Changes</summary>

When tiling a chain of linalg ops, only the first op's tile sizes can be set to 0 to mark a dimension as untiled; its producers instead receive a tile size equal to the full loop trip count. In those cases we must return the full slice, because the code that computes slice sizes in the general case does not handle non-monotonic affine expressions (e.g. `d0 mod 4`). Otherwise, we would generate invalid code for non-monotonic expressions even when all involved dimensions are effectively untiled.

---
Full diff: https://github.com/llvm/llvm-project/pull/127197.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+6-3) 
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+25-10) 
- (modified) mlir/test/Dialect/Linalg/tile-tensors.mlir (+32) 
- (modified) mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir (+2-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index b7764da26a7f4..a838b99c9dbb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -115,13 +115,16 @@ struct LinalgOpTilingInterface
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
-    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
-    // specified could lead to out of bounds accesses.
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<OpFoldResult> allShapeSizes =
+        linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc());
+    SmallVector<OpFoldResult> sizeBounds =
+        mlir::affine::makeComposedFoldedMultiResultAffineApply(
+            b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes);
     SmallVector<Value> valuesToTile = linalgOp->getOperands();
     SmallVector<Value> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+        b, loc, linalgOp, valuesToTile, offsets, sizes, sizeBounds, true);
     SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
         llvm::make_filter_range(
             tiledOperands,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d148067fe6343..3f0382f4bc8b6 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -56,10 +56,23 @@ namespace {
 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
 //
 struct TileCheck : public AffineExprVisitor<TileCheck> {
-  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
+  TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds)
+      : tileSizes(tileSizes), sizeBounds(sizeBounds) {}
 
   void visitDimExpr(AffineDimExpr expr) {
-    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
+    unsigned pos = expr.getPosition();
+
+    // This dimension is tiled if the tile size is larger than zero and not
+    // equal to its domain size (if statically known).
+    std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
+    if (tileSize && !sizeBounds.empty()) {
+      std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
+      if (sizeBound && *sizeBound == *tileSize) {
+        return;
+      }
+    }
+
+    isTiled |= !isZeroIndex(tileSizes[pos]);
   }
   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
     visit(expr.getLHS());
@@ -70,24 +83,27 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
   }
   bool isTiled = false;
   ArrayRef<OpFoldResult> tileSizes;
+  ArrayRef<OpFoldResult> sizeBounds;
 };
 
 } // namespace
 
-static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
+static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
+                    ArrayRef<OpFoldResult> sizeBounds) {
   if (!expr)
     return false;
-  TileCheck t(tileSizes);
+  TileCheck t(tileSizes, sizeBounds);
   t.visit(expr);
   return t.isTiled;
 }
 
 // Checks whether the `map  varies with respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
+static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes,
+                    ArrayRef<OpFoldResult> sizeBounds) {
   if (!map)
     return false;
   for (unsigned r = 0; r < map.getNumResults(); ++r)
-    if (isTiled(map.getResult(r), tileSizes))
+    if (isTiled(map.getResult(r), tileSizes, sizeBounds))
       return true;
   return false;
 }
@@ -581,7 +597,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
   sliceParams.strides.reserve(rank);
   for (unsigned r = 0; r < rank; ++r) {
     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
-    if (!isTiled(map.getSubMap({r}), tileSizes)) {
+    if (!isTiled(map.getSubMap({r}), tileSizes, ubs)) {
       sliceParams.offsets.push_back(builder.getIndexAttr(0));
       OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
       sliceParams.sizes.push_back(dim);
@@ -781,10 +797,9 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
     // transformations such as padding and bufferization since the
     // extract/insert slice pairs make the accessed iteration argument
     // subdomains explicit.
-
     Type operandType = opOperand.get().getType();
-    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
-                                      linalgOp.isDpsInit(&opOperand))) {
+    if (!isTiled(map, tileSizes, {}) && !(isa<RankedTensorType>(operandType) &&
+                                          linalgOp.isDpsInit(&opOperand))) {
       allSliceParams.push_back(std::nullopt);
       LLVM_DEBUG(llvm::dbgs()
                  << ": not tiled: use shape: " << operandType << "\n");
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 557233d8aa3ec..b179a5966cc08 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -214,3 +214,35 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @non_monotonic_affine_expr
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<7xf32>
+func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<7xf32>
+  %empty = tensor.empty() : tensor<7xf32>
+
+  // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32>
+  // CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) {
+  // CHECK: tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32>
+  %generic = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> (d0 mod 4)>,
+                      affine_map<(d0) -> (d0)>],
+     iterator_types = ["parallel"]}
+    ins(%arg0: tensor<7xf32>)
+    outs(%empty : tensor<7xf32>) {
+    ^bb0(%in : f32, %out: f32):
+      linalg.yield %in : f32
+    } -> tensor<7xf32>
+  return %generic : tensor<7xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop = transform.structured.tile_using_for %0 tile_sizes [7] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c..525561b8ea2f5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -555,11 +555,12 @@ module {
 
       // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
       // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
+      // CHECK: %[[T3:.*]] = linalg.generic {{.*}}
       %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
 
       %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
       scf.forall.in_parallel {
-        // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        // CHECK: tensor.parallel_insert_slice %[[T3]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
         tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
       }
     }

``````````

</details>


https://github.com/llvm/llvm-project/pull/127197


More information about the Mlir-commits mailing list