[Mlir-commits] [mlir] [MLIR][LINALG] Address a TODO in Bubble up extract slice (PR #79078)
Javed Absar
llvmlistbot at llvm.org
Mon Jan 22 16:01:56 PST 2024
https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/79078
>From bb5e500bd9f1f1aa8c1ac678901928de3893610b Mon Sep 17 00:00:00 2001
From: mabsar <javed.absar at gmail.com>
Date: Mon, 22 Jan 2024 12:34:59 -0500
Subject: [PATCH 1/2] [MLIR][LINALG] Address a TODO in Bubble up extract slice
This change extends the optimization
`Bubble up extract_slice above Linalg operation`
in cases where the linalg op has more than one user.
---
.../Transforms/BubbleUpExtractSlice.cpp | 30 ++++---
.../bubble-up-extract-slice-multi-use.mlir | 85 +++++++++++++++++++
2 files changed, 102 insertions(+), 13 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 428422e6e875a27..beeef7378fd3d80 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -42,7 +42,9 @@ namespace {
/// ```
///
/// This results in the reduce computation of the linalg operation.
-///
+/// In case linalg op has multiple uses we optimize only only if
+/// each use is a small portion of the result i.e. each use is an
+/// extract_slice.
struct BubbleUpExtractSliceOpPattern
: OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -56,13 +58,6 @@ struct BubbleUpExtractSliceOpPattern
"expected source to be linalg op");
}
- // TODO: we might relax this if we want heuristics to detect that all uses
- // are small portion of the output.
- if (!linalgOp->hasOneUse()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "expected single use of linalg op");
- }
-
if (linalgOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(sliceOp,
"expected single output of linalg op");
@@ -73,11 +68,20 @@ struct BubbleUpExtractSliceOpPattern
"expected tensor of linalg op");
}
- if (!sliceOp.hasUnitStride())
- return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
-
- if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
- return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
+ // Check that all uses (including sliceOp) are small portion of output
+ // and satisfy the constraints.
+ for (Operation *user : linalgOp->getResult(0).getUsers()) {
+ auto sliceOpOther = dyn_cast<tensor::ExtractSliceOp>(user);
+ if (!sliceOpOther)
+ return rewriter.notifyMatchFailure(sliceOp,
+ "expected single use of linalg op");
+ if (!sliceOpOther.hasUnitStride())
+ return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
+
+ if (sliceOpOther.getType().getRank() !=
+ sliceOpOther.getSourceType().getRank())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "expected no rank reduction");
}
OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir
new file mode 100644
index 000000000000000..cd5aa0feafc8425
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir
@@ -0,0 +1,85 @@
+//RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s
+
+func.func @multi_extract_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>,
+ %arg2: index, %arg3: index, %arg4: index, %arg5:index,
+ %arg6: index, %arg7: index, %arg8: index, %arg9:index
+ ) -> tensor<?x?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %add = arith.addf %b0, %b1 : f32
+ linalg.yield %add : f32
+ } -> tensor<?x?xf32>
+
+ %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+
+ %2 = tensor.extract_slice %0 [%arg6, %arg7] [%arg8, %arg9] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+
+ %3 = tensor.concat dim(0) %1, %2 :
+ (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ return %3 : tensor<?x?xf32>
+}
+// CHECK: func @multi_extract_slice
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[GENERIC_0:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
+//
+// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %arg0[%arg6, %arg7] [%arg8, %arg9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE4:.+]] = tensor.extract_slice %arg1[%arg7] [%arg9] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE5:.+]] = tensor.extract_slice %arg0[%arg6, %arg7] [%arg8, %arg9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[GENERIC_1:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE3]], %[[SLICE4]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE5]] : tensor<?x?xf32>)
+//
+// CHECK: %[[CONCAT:.+]] = tensor.concat dim(0) %[[GENERIC_0]], %[[GENERIC_1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
+
+//-----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+module {
+ func.func @multi_mixed_use(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>,
+ %arg3: tensor<f32>, %arg4: index, %arg5: index, %arg6: index,
+ %arg7: index) -> tensor<f32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_3: f32, %out: f32):
+ %2 = arith.addf %in, %in_3 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ %extract = tensor.extract_slice %0[%arg4, %arg5] [%arg6, %arg7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+ %reduced = linalg.reduce { arith.addf } ins(%extract : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+ %reduced_0 = linalg.reduce { arith.addf } ins(%reduced : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+
+ %reduced_1 = linalg.reduce { arith.addf } ins(%0 : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+ %reduced_2 = linalg.reduce { arith.addf } ins(%reduced_1 : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+
+ %1 = arith.divf %reduced_0, %reduced_2 : tensor<f32>
+ return %1 : tensor<f32>
+ }
+}
+
+// CHECK: func @multi_mixed_use
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%arg4, %arg5] [%arg6, %arg7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//
+// CHECK: %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[EXTRACT]] : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+// CHECK: %[[REDUCED_0:.+]] = linalg.reduce { arith.addf } ins(%[[REDUCED]] : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+//
+// CHECK: %[[REDUCED_1:.+]] = linalg.reduce { arith.addf } ins(%[[GENERIC]] : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+// CHECK: %[[REDUCED_2:.+]] = linalg.reduce { arith.addf } ins(%[[REDUCED_1]] : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+//
+// CHECK: %[[DIV:.+]] = arith.divf %[[REDUCED_0]], %[[REDUCED_2]] : tensor<f32>
+// CHECK: return %[[DIV]] : tensor<f32>
>From 41d8e26577752ddd018454dcfb4a54c5c53331f9 Mon Sep 17 00:00:00 2001
From: mabsar <javed.absar at gmail.com>
Date: Mon, 22 Jan 2024 19:01:08 -0500
Subject: [PATCH 2/2] [MLIR][Linalg] Fix comment
---
.../Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index beeef7378fd3d80..7e295f4095e726b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -42,8 +42,8 @@ namespace {
/// ```
///
/// This results in the reduce computation of the linalg operation.
-/// In case linalg op has multiple uses we optimize only only if
-/// each use is a small portion of the result i.e. each use is an
+/// In case linalg op has multiple uses we optimize only if each
+/// use is a small portion of the result i.e. each use is an
/// extract_slice.
struct BubbleUpExtractSliceOpPattern
: OpRewritePattern<tensor::ExtractSliceOp> {
@@ -76,11 +76,11 @@ struct BubbleUpExtractSliceOpPattern
return rewriter.notifyMatchFailure(sliceOp,
"expected single use of linalg op");
if (!sliceOpOther.hasUnitStride())
- return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
+ return rewriter.notifyMatchFailure(sliceOpOther, "expected unit stride");
if (sliceOpOther.getType().getRank() !=
sliceOpOther.getSourceType().getRank())
- return rewriter.notifyMatchFailure(sliceOp,
+ return rewriter.notifyMatchFailure(sliceOpOther,
"expected no rank reduction");
}
More information about the Mlir-commits
mailing list