[Mlir-commits] [mlir] [mlir][scf]: Avoid inserting affine.min when tiling dynamic operation sizes if possible (PR #113819)
Aviad Cohen
llvmlistbot at llvm.org
Tue Oct 29 05:38:15 PDT 2024
https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/113819
>From d90fa69413be0f8289ba8098187520e0a1549a7b Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 27 Oct 2024 14:22:47 +0200
Subject: [PATCH] [mlir][scf]: Avoid inserting affine.min when tiling dynamic
operation sizes if possible
* When tiling an operation with scf, the affine.min that clamps the last
tile can be omitted if `upper_bound = step * k` (and the lower bound is also
a multiple of the step), where the step may be a constant or a dynamic value.
---
.../SCF/Transforms/TileUsingInterface.cpp | 45 ++++++++++++++----
.../Dialect/Linalg/transform-op-tile.mlir | 47 +++++++++++++++++++
2 files changed, 84 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..4f0a469f138007 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -21,11 +21,13 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -186,18 +188,45 @@ static void checkSafeToTileToForall(TilingInterface op,
}
}
+/// Collects `ofr` and the factors of any `arith.muli` chain that defines it.
+static void collectDividers(OpFoldResult ofr,
+                            SmallVectorImpl<OpFoldResult> &dividers) {
+  dividers.push_back(ofr);
+  auto value = llvm::dyn_cast_if_present<Value>(ofr);
+  if (!value)
+    return;
+  if (auto mulOp = value.getDefiningOp<arith::MulIOp>()) {
+    // `ofr` = `x` * `y`: every divisor of `x` or `y` also divides `ofr`.
+    collectDividers(mulOp.getLhs(), dividers);
+    collectDividers(mulOp.getRhs(), dividers);
+  }
+}
+
/// Check if `stride` evenly divides the trip count `size - offset`.
static bool tileDividesIterationDomain(Range loopRange) {
+ std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
- if (!offsetAsInt)
- return false;
std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
- if (!sizeAsInt)
- return false;
- std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
- if (!strideAsInt)
- return false;
- return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
+ if (strideAsInt && offsetAsInt && sizeAsInt)
+ // `stride`/`size`/`offset` are static, checking (size - offset) % stride = 0.
+ return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
+
+ // At least `stride`/`size`/`offset` is dynamic.
+ SmallVector<OpFoldResult> dividersOfSize, dividersOfOffset;
+ collectDividers(loopRange.size, dividersOfSize);
+ collectDividers(loopRange.offset, dividersOfOffset);
+
+ // Return true if `stride` divides one of the dividers of both `size` and `offset`.
+ auto isStrideDividesDivider = [&](OpFoldResult divider) {
+ if (!strideAsInt)
+ // `stride` is dynamic.
+ return divider == loopRange.stride;
+
+ std::optional<int64_t> dividerAsInt = getConstantIntValue(divider);
+ return dividerAsInt && *dividerAsInt % *strideAsInt == 0;
+ };
+ return llvm::any_of(dividersOfSize, isStrideDividesDivider) &&
+ llvm::any_of(dividersOfOffset, isStrideDividesDivider);
}
/// Returns the bounded tile size given the current `offset`, `loopRange` and
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 7bac850d0b7fe9..ade523ef378f36 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
-> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: split_dynamic_linalg_generic
+func.func @split_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
+  %c80 = arith.constant 80 : index
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg1, %c0 : tensor<?xi16>
+  %0 = tensor.empty(%dim) : tensor<?xi16>
+  %1 = arith.divui %dim, %c80 : index
+  %2 = arith.muli %1, %c80 : index
+  %3 = arith.remui %dim, %c80 : index
+  %extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK: scf.for
+  // CHECK-NOT: affine.min
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %6 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %6 : i16
+  } -> tensor<?xi16>
+  %inserted_slice = tensor.insert_slice %4 into %0[0] [%2] [1] : tensor<?xi16> into tensor<?xi16>
+  %extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK-NOT: scf.for
+  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %7 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %7 : i16
+  } -> tensor<?xi16>
+  %inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
+  return %inserted_slice_0 : tensor<?xi16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // Sizes pair with generics in IR order: %c80 tiles the first one, %c0 leaves the second untiled.
+    %const = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop = transform.structured.tile_using_for %0 tile_sizes [%const] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list