[Mlir-commits] [mlir] [MLIR][LINALG] Address a TODO in Bubble up extract slice (PR #79078)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 22 15:54:34 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Javed Absar (javedabsar1)
<details>
<summary>Changes</summary>
This change extends the optimization
`Bubble up extract_slice above Linalg operation`
in cases where linalg op has more than one user.
---
Full diff: https://github.com/llvm/llvm-project/pull/79078.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp (+17-13)
- (added) mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir (+85)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 428422e6e875a27..beeef7378fd3d80 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -42,7 +42,9 @@ namespace {
/// ```
///
/// This results in the reduce computation of the linalg operation.
-///
+/// In case the linalg op has multiple uses, we apply the optimization
+/// only if each use consumes a small portion of the result, i.e. each
+/// use is an extract_slice.
struct BubbleUpExtractSliceOpPattern
: OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -56,13 +58,6 @@ struct BubbleUpExtractSliceOpPattern
"expected source to be linalg op");
}
- // TODO: we might relax this if we want heuristics to detect that all uses
- // are small portion of the output.
- if (!linalgOp->hasOneUse()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "expected single use of linalg op");
- }
-
if (linalgOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(sliceOp,
"expected single output of linalg op");
@@ -73,11 +68,20 @@ struct BubbleUpExtractSliceOpPattern
"expected tensor of linalg op");
}
- if (!sliceOp.hasUnitStride())
- return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
-
- if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
- return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
+ // Check that all uses (including sliceOp) are small portion of output
+ // and satisfy the constraints.
+ for (Operation *user : linalgOp->getResult(0).getUsers()) {
+ auto sliceOpOther = dyn_cast<tensor::ExtractSliceOp>(user);
+ if (!sliceOpOther)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "expected all uses of linalg op to be extract_slice ops");
+ if (!sliceOpOther.hasUnitStride())
+ return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
+
+ if (sliceOpOther.getType().getRank() !=
+ sliceOpOther.getSourceType().getRank())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "expected no rank reduction");
}
OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir
new file mode 100644
index 000000000000000..cd5aa0feafc8425
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir
@@ -0,0 +1,85 @@
+//RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s
+
+func.func @multi_extract_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>,
+ %arg2: index, %arg3: index, %arg4: index, %arg5:index,
+ %arg6: index, %arg7: index, %arg8: index, %arg9:index
+ ) -> tensor<?x?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %add = arith.addf %b0, %b1 : f32
+ linalg.yield %add : f32
+ } -> tensor<?x?xf32>
+
+ %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+
+ %2 = tensor.extract_slice %0 [%arg6, %arg7] [%arg8, %arg9] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+
+ %3 = tensor.concat dim(0) %1, %2 :
+ (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ return %3 : tensor<?x?xf32>
+}
+// CHECK: func @multi_extract_slice
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[GENERIC_0:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
+//
+// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %arg0[%arg6, %arg7] [%arg8, %arg9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE4:.+]] = tensor.extract_slice %arg1[%arg7] [%arg9] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE5:.+]] = tensor.extract_slice %arg0[%arg6, %arg7] [%arg8, %arg9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[GENERIC_1:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE3]], %[[SLICE4]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE5]] : tensor<?x?xf32>)
+//
+// CHECK: %[[CONCAT:.+]] = tensor.concat dim(0) %[[GENERIC_0]], %[[GENERIC_1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+module {
+ func.func @multi_mixed_use(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>,
+ %arg3: tensor<f32>, %arg4: index, %arg5: index, %arg6: index,
+ %arg7: index) -> tensor<f32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_3: f32, %out: f32):
+ %2 = arith.addf %in, %in_3 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ %extract = tensor.extract_slice %0[%arg4, %arg5] [%arg6, %arg7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+ %reduced = linalg.reduce { arith.addf } ins(%extract : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+ %reduced_0 = linalg.reduce { arith.addf } ins(%reduced : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+
+ %reduced_1 = linalg.reduce { arith.addf } ins(%0 : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+ %reduced_2 = linalg.reduce { arith.addf } ins(%reduced_1 : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+
+ %1 = arith.divf %reduced_0, %reduced_2 : tensor<f32>
+ return %1 : tensor<f32>
+ }
+}
+
+// CHECK: func @multi_mixed_use
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%arg4, %arg5] [%arg6, %arg7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//
+// CHECK: %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[EXTRACT]] : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+// CHECK: %[[REDUCED_0:.+]] = linalg.reduce { arith.addf } ins(%[[REDUCED]] : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+//
+// CHECK: %[[REDUCED_1:.+]] = linalg.reduce { arith.addf } ins(%[[GENERIC]] : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+// CHECK: %[[REDUCED_2:.+]] = linalg.reduce { arith.addf } ins(%[[REDUCED_1]] : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+//
+// CHECK: %[[DIV:.+]] = arith.divf %[[REDUCED_0]], %[[REDUCED_2]] : tensor<f32>
+// CHECK: return %[[DIV]] : tensor<f32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/79078
More information about the Mlir-commits
mailing list