[Mlir-commits] [mlir] [mlir][scf]: Avoid inserting affine.min when tiling dynamic operation sizes if possible (PR #113819)

Aviad Cohen llvmlistbot at llvm.org
Sun Oct 27 08:53:28 PDT 2024


https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/113819

* When tiling an operation using scf, avoid inserting an affine.min for the last tile when the loop upper bound is known to be a multiple of the step, i.e. `upper_bound = step * k`, where `step` may be either a constant or a dynamic value.

>From d88f7490218ac56df2176b17decb3d4249a43373 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 27 Oct 2024 14:22:47 +0200
Subject: [PATCH] [mlir][scf]: Avoid inserting affine.min when tiling dynamic
 operation sizes if possible

* When tiling an operation using scf, avoid inserting an affine.min for
  the last tile when the loop upper bound is known to be a multiple of the
  step, i.e. `upper_bound = step * k`, where `step` may be either a
  constant or a dynamic value.
---
 .../SCF/Transforms/TileUsingInterface.cpp     | 24 ++++++++++
 .../Dialect/Linalg/transform-op-tile.mlir     | 47 +++++++++++++++++++
 2 files changed, 71 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..54afbe4b032bd4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -186,11 +186,35 @@ static void checkSafeToTileToForall(TilingInterface op,
   }
 }
 
+/// Returns true if `size` is provably a multiple of `stride`, i.e. `size` is
+/// the result of an `arith.muli` that has `stride` as one of its operands,
+/// either as the same SSA value or as a constant with the same value.
+static bool dynamiclyDivisible(OpFoldResult size, OpFoldResult stride) {
+  auto dynamicSize = dyn_cast_if_present<Value>(size);
+  if (!dynamicSize)
+    return false;
+  auto mulOp = dynamicSize.getDefiningOp<arith::MulIOp>();
+  if (!mulOp)
+    return false;
+  Value strideValue = dyn_cast_if_present<Value>(stride);
+  std::optional<int64_t> strideAsInt = getConstantIntValue(stride);
+  // Match constants by value too; they may come from distinct constant ops.
+  auto isStride = [&](Value operand) {
+    if (strideValue && operand == strideValue)
+      return true;
+    std::optional<int64_t> operandAsInt = getConstantIntValue(operand);
+    return strideAsInt && operandAsInt && *strideAsInt == *operandAsInt;
+  };
+  return isStride(mulOp.getLhs()) || isStride(mulOp.getRhs());
+}
+
 /// Check if `stride` evenly divides the trip count `size - offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
   if (!offsetAsInt)
     return false;
+  if (*offsetAsInt == 0 && dynamiclyDivisible(loopRange.size, loopRange.stride))
+    return true;
   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
   if (!sizeAsInt)
     return false;
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 7bac850d0b7fe9..ade523ef378f36 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
     -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func.func @splited_dynamic_linalg_generic
+func.func @splited_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
+  %c80 = arith.constant 80 : index  // Expected tile size of the first linalg.generic; keep it first.
+  %c0 = arith.constant 0 : index  // Expected tile size 0 (no tiling) of the second linalg.generic.
+  %dim = tensor.dim %arg1, %c0 : tensor<?xi16>
+  %0 = tensor.empty(%dim) : tensor<?xi16>
+  %1 = arith.divui %dim, %c80 : index
+  %2 = arith.muli %1, %c80 : index  // Main part size: a multiple of 80, so no partial last tile.
+  %3 = arith.remui %dim, %c80 : index  // Tail part size: the remaining elements.
+  %extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK: scf.for
+  // CHECK-NOT: affine.min
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %6 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %6 : i16
+  } -> tensor<?xi16>
+  %inserted_slice = tensor.insert_slice %4 into %0[0] [%2] [1] : tensor<?xi16> into tensor<?xi16>
+  %extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK-NOT: scf.for
+  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %7 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %7 : i16
+  } -> tensor<?xi16>
+  %inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
+  return %inserted_slice_0 : tensor<?xi16>
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sizes = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // One size op per matched generic, in order.
+    %tiled, %loops = transform.structured.tile_using_for %generics tile_sizes [%sizes] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list