[Mlir-commits] [mlir] [mlir][scf]: Avoid inserting affine.min when tiling dynamic operation sizes if possible (PR #113819)
Aviad Cohen
llvmlistbot at llvm.org
Sun Oct 27 08:53:28 PDT 2024
https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/113819
* When tiling an operation using scf, we can avoid inserting affine.min to handle the last tile when `upper_bound = step * k`, where the step (stride) may be a constant or a dynamic value.
>From d88f7490218ac56df2176b17decb3d4249a43373 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 27 Oct 2024 14:22:47 +0200
Subject: [PATCH] [mlir][scf]: Avoid inserting affine.min when tiling dynamic
operation sizes if possible
* When tiling an operation using scf, we can avoid inserting affine.min
to handle the last tile when `upper_bound = step * k`, where the step
(stride) may be a constant or a dynamic value.
---
.../SCF/Transforms/TileUsingInterface.cpp | 24 ++++++++++
.../Dialect/Linalg/transform-op-tile.mlir | 47 +++++++++++++++++++
2 files changed, 71 insertions(+)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..54afbe4b032bd4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -186,11 +186,35 @@ static void checkSafeToTileToForall(TilingInterface op,
}
}
+/// Returns true if `size` is a Value defined by an `arith::MulIOp` that has
+/// `stride` as an operand (the same Value, or an equal constant integer).
+static bool dynamiclyDivisible(OpFoldResult size, OpFoldResult stride) {
+ Value dynamicSize = dyn_cast_if_present<Value>(size);
+ if (!dynamicSize) // `size` does not hold a Value: nothing to inspect.
+ return false;
+ auto mulOp = dynamicSize.getDefiningOp<arith::MulIOp>();
+ if (!mulOp) // `size` is not the result of an arith::MulIOp.
+ return false;
+ if (Value dynamicStride = dyn_cast_if_present<Value>(stride)) // Value stride: matched by identity only.
+ return mulOp.getLhs() == dynamicStride || mulOp.getRhs() == dynamicStride;
+ std::optional<int64_t> strideAsInt = getConstantIntValue(stride); // Non-Value stride: compare constants.
+ std::optional<int64_t> lhsAsInt = getConstantIntValue(mulOp.getLhs());
+ std::optional<int64_t> rhsAsInt = getConstantIntValue(mulOp.getRhs());
+ if (strideAsInt && lhsAsInt && *strideAsInt == *lhsAsInt)
+ return true;
+ if (strideAsInt && rhsAsInt && *strideAsInt == *rhsAsInt)
+ return true;
+
+ return false;
+}
+
/// Check if `stride` evenly divides the trip count `size - offset`.
static bool tileDividesIterationDomain(Range loopRange) {
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
if (!offsetAsInt)
return false;
+ if (*offsetAsInt == 0 && dynamiclyDivisible(loopRange.size, loopRange.stride))
+ return true;
std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
if (!sizeAsInt)
return false;
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 7bac850d0b7fe9..ade523ef378f36 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
-> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: splited_dynamic_linalg_generic
+func.func @splited_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
+ %c80 = arith.constant 80 : index
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg1, %c0 : tensor<?xi16>
+ %0 = tensor.empty(%dim) : tensor<?xi16>
+ %1 = arith.divui %dim, %c80 : index // %1 = %dim / 80 (unsigned division)
+ %2 = arith.muli %1, %c80 : index // %2 = %1 * 80: size of the first slice, a multiple of 80
+ %3 = arith.remui %dim, %c80 : index // %3 = %dim rem 80: size of the second (remainder) slice
+ %extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+ %extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+ %extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+ // CHECK: scf.for
+ // CHECK-NOT: affine.min
+ %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
+ ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+ %6 = arith.addi %in_1, %in_2 : i16
+ linalg.yield %6 : i16
+ } -> tensor<?xi16>
+ %inserted_slice = tensor.insert_slice %4 into %0[%2] [%2] [1] : tensor<?xi16> into tensor<?xi16>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+ %extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+ %extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+ // CHECK-NOT: scf.for
+ %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
+ ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+ %7 = arith.addi %in_1, %in_2 : i16
+ linalg.yield %7 : i16
+ } -> tensor<?xi16>
+ %inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
+ return %inserted_slice_0 : tensor<?xi16>
+}
+
+
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op // handle to linalg.generic ops matched in %arg1
+ %const = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op // handle to arith.constant ops matched in %arg1
+ %1, %loop = transform.structured.tile_using_for %0 tile_sizes [%const] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) // tile %0, taking the tile size from the %const handle
+ transform.yield
+}
+}
More information about the Mlir-commits
mailing list