[Mlir-commits] [mlir] [MLIR][LINALG] Address a TODO in Bubble up extract slice (PR #79078)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 22 15:54:34 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Javed Absar (javedabsar1)

<details>
<summary>Changes</summary>

This change extends the optimization
 `Bubble up extract_slice above Linalg operation`
to cases where the linalg op has more than one user.

---
Full diff: https://github.com/llvm/llvm-project/pull/79078.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp (+17-13) 
- (added) mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir (+85) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 428422e6e875a27..beeef7378fd3d80 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -42,7 +42,9 @@ namespace {
 /// ```
 ///
 /// This results in the reduce computation of the linalg operation.
-///
+/// If the linalg op has multiple uses, the rewrite is applied only when
+/// every use is an extract_slice, as a proxy for each use reading only a
+/// portion of the result.
 struct BubbleUpExtractSliceOpPattern
     : OpRewritePattern<tensor::ExtractSliceOp> {
   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -56,13 +58,6 @@ struct BubbleUpExtractSliceOpPattern
                                          "expected source to be linalg op");
     }
 
-    // TODO: we might relax this if we want heuristics to detect that all uses
-    // are small portion of the output.
-    if (!linalgOp->hasOneUse()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "expected single use of linalg op");
-    }
-
     if (linalgOp.getNumDpsInits() != 1) {
       return rewriter.notifyMatchFailure(sliceOp,
                                          "expected single output of linalg op");
@@ -73,11 +68,20 @@ struct BubbleUpExtractSliceOpPattern
                                          "expected tensor of linalg op");
     }
 
-    if (!sliceOp.hasUnitStride())
-      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
-
-    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
-      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
+    // Check that all uses (including sliceOp) are small portion of output
+    // and satisfy the constraints.
+    for (Operation *user : linalgOp->getResult(0).getUsers()) {
+      auto sliceOpOther = dyn_cast<tensor::ExtractSliceOp>(user);
+      if (!sliceOpOther)
+        return rewriter.notifyMatchFailure(sliceOp,
+                                           "expected single use of linalg op");
+      if (!sliceOpOther.hasUnitStride())
+        return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
+
+      if (sliceOpOther.getType().getRank() !=
+          sliceOpOther.getSourceType().getRank())
+        return rewriter.notifyMatchFailure(sliceOp,
+                                           "expected no rank reduction");
     }
 
     OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir
new file mode 100644
index 000000000000000..cd5aa0feafc8425
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-multi-use.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s
+// Both uses of the generic are extract_slices, so each slice gets its own clone of the generic on sliced operands.
+func.func @multi_extract_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>,
+                   %arg2: index, %arg3: index, %arg4: index, %arg5:index,
+                   %arg6: index, %arg7: index, %arg8: index, %arg9:index
+                   ) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+
+  %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
+    : tensor<?x?xf32> to tensor<?x?xf32>
+
+  %2 = tensor.extract_slice %0 [%arg6, %arg7] [%arg8, %arg9] [1, 1]
+    : tensor<?x?xf32> to tensor<?x?xf32>
+
+  %3 = tensor.concat dim(0) %1, %2 :
+    (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  return %3 : tensor<?x?xf32>
+}
+//      CHECK: func @multi_extract_slice
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[GENERIC_0:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
+//
+//      CHECK: %[[SLICE3:.+]] = tensor.extract_slice %arg0[%arg6, %arg7] [%arg8, %arg9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[SLICE4:.+]] = tensor.extract_slice %arg1[%arg7] [%arg9] [1] : tensor<?xf32> to tensor<?xf32>
+//      CHECK: %[[SLICE5:.+]] = tensor.extract_slice %arg0[%arg6, %arg7] [%arg8, %arg9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[GENERIC_1:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE3]], %[[SLICE4]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE5]] : tensor<?x?xf32>)
+//
+//      CHECK: %[[CONCAT:.+]] = tensor.concat dim(0) %[[GENERIC_0]], %[[GENERIC_1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+//      CHECK: return %[[CONCAT]] : tensor<?x?xf32>
+
+// -----
+// The generic also feeds a linalg.reduce directly (not via extract_slice), so the pattern must not fire.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+module {
+  func.func @multi_mixed_use(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>,
+                     %arg3: tensor<f32>, %arg4: index, %arg5: index, %arg6: index,
+                     %arg7: index) -> tensor<f32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+           ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+           outs(%arg0 : tensor<?x?xf32>) {
+      ^bb0(%in: f32, %in_3: f32, %out: f32):
+        %2 = arith.addf %in, %in_3 : f32
+        linalg.yield %2 : f32
+      } -> tensor<?x?xf32>
+    %extract = tensor.extract_slice %0[%arg4, %arg5] [%arg6, %arg7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+    %reduced = linalg.reduce { arith.addf } ins(%extract : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+    %reduced_0 = linalg.reduce { arith.addf } ins(%reduced : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+
+    %reduced_1 = linalg.reduce { arith.addf } ins(%0 : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+    %reduced_2 = linalg.reduce { arith.addf } ins(%reduced_1 : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+
+    %1 = arith.divf %reduced_0, %reduced_2 : tensor<f32>
+    return %1 : tensor<f32>
+  }
+}
+// Expected output: the generic is kept whole and the extract_slice stays on its result.
+// CHECK: func @multi_mixed_use
+// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%arg4, %arg5] [%arg6, %arg7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//
+// CHECK: %[[REDUCED:.+]] = linalg.reduce { arith.addf } ins(%[[EXTRACT]] : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+// CHECK: %[[REDUCED_0:.+]] = linalg.reduce { arith.addf } ins(%[[REDUCED]] : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+//
+// CHECK: %[[REDUCED_1:.+]] = linalg.reduce { arith.addf } ins(%[[GENERIC]] : tensor<?x?xf32>) outs(%arg2 : tensor<?xf32>) dimensions = [1]
+// CHECK: %[[REDUCED_2:.+]] = linalg.reduce { arith.addf } ins(%[[REDUCED_1]] : tensor<?xf32>) outs(%arg3 : tensor<f32>) dimensions = [0]
+//
+// CHECK: %[[DIV:.+]] = arith.divf %[[REDUCED_0]], %[[REDUCED_2]] : tensor<f32>
+// CHECK: return %[[DIV]] : tensor<f32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/79078


More information about the Mlir-commits mailing list