[Mlir-commits] [mlir] [mlir][scf]: Avoid inserting affine.min when tiling dynamic operation sizes if possible (PR #113819)

Aviad Cohen llvmlistbot at llvm.org
Tue Oct 29 05:42:53 PDT 2024

https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/113819

>From 7a5ab7c53ce3d7e9ec1408360e3037758d33d897 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 27 Oct 2024 14:22:47 +0200
Subject: [PATCH] [mlir][scf]: Avoid inserting affine.min when tiling dynamic
 operation sizes if possible

* During operation tiling using scf, we may avoid inserting affine.min
  to handle the last tile when `upper_bound = step * k`, where the step
  may be either a constant or a dynamic value.
 .../SCF/Transforms/TileUsingInterface.cpp     | 49 ++++++++++++++++---
 .../Dialect/Linalg/transform-op-tile.mlir     | 47 ++++++++++++++++++
 2 files changed, 88 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..ecb7c265305bd8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -21,11 +21,13 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -186,18 +188,49 @@ static void checkSafeToTileToForall(TilingInterface op,
+/// Collect the dividers of `ofr`.
+static void collectDividers(OpFoldResult ofr,
+                            SmallVector<OpFoldResult> &dividers) {
+  dividers.push_back(ofr);
+  if (ofr.is<Attribute>())
+    return;
+  auto mulOp = cast<Value>(ofr).getDefiningOp<arith::MulIOp>();
+  if (!mulOp)
+    return;
+  // Given `ofr` = `x` * `y`, all dividers of `x` and `y` are dividers of `ofr`.
+  collectDividers(mulOp.getLhs(), dividers);
+  collectDividers(mulOp.getRhs(), dividers);
 /// Check if `stride` evenly divides the trip count `size - offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
+  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
-  if (!offsetAsInt)
-    return false;
   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
-  if (!sizeAsInt)
-    return false;
-  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
-  if (!strideAsInt)
-    return false;
-  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
+  if (strideAsInt && offsetAsInt && sizeAsInt)
+    // `stride`/`size`/`offset` are static, checking (size - offset) % stride =
+    // 0.
+    return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() ==
+            0);
+  // At least one of `stride`/`size`/`offset` is dynamic.
+  SmallVector<OpFoldResult> dividersOfSize, dividersOfOffset;
+  collectDividers(loopRange.size, dividersOfSize);
+  collectDividers(loopRange.offset, dividersOfOffset);
+  // Return true if `stride` divides one of the dividers of both `size` and
+  // `offset`.
+  auto isStrideDividesDivider = [&](OpFoldResult divider) {
+    if (!strideAsInt)
+      // `stride` is dynamic.
+      return divider == loopRange.stride;
+    std::optional<int64_t> dividerAsInt = getConstantIntValue(divider);
+    return dividerAsInt && *dividerAsInt % *strideAsInt == 0;
+  };
+  return llvm::any_of(dividersOfSize, isStrideDividesDivider) &&
+         llvm::any_of(dividersOfOffset, isStrideDividesDivider);
 /// Returns the bounded tile size given the current `offset`, `loopRange` and
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 7bac850d0b7fe9..ade523ef378f36 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
     -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
+// -----
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: splited_dynamic_linalg_generic
+func.func @splited_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
+  %c80 = arith.constant 80 : index
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg1, %c0 : tensor<?xi16>
+  %0 = tensor.empty(%dim) : tensor<?xi16>
+  %1 = arith.divui %dim, %c80 : index
+  %2 = arith.muli %1, %c80 : index
+  %3 = arith.remui %dim, %c80 : index
+  %extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK: scf.for
+  // CHECK-NOT: affine.min
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %6 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %6 : i16
+  } -> tensor<?xi16>
+  %inserted_slice = tensor.insert_slice %4 into %0[%2] [%2] [1] : tensor<?xi16> into tensor<?xi16>
+  %extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK-NOT: scf.for
+  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %7 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %7 : i16
+  } -> tensor<?xi16>
+  %inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
+  return %inserted_slice_0 : tensor<?xi16>
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %const = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop = transform.structured.tile_using_for %0 tile_sizes [%const] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield

More information about the Mlir-commits mailing list