[Mlir-commits] [mlir] Extend `getBackwardSlice` to track values captured from above (PR #113478)

Ian Wood llvmlistbot at llvm.org
Tue Oct 29 20:52:05 PDT 2024


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/113478

>From 29209330d68e877cb92eca9f48b782244a97ecf9 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Wed, 23 Oct 2024 16:46:02 +0000
Subject: [PATCH 1/2] Add `omitUsesFromAbove` to `getBackwardSlice`

`getBackwardSlice` should track the values that are captured (used but
defined above) by the regions of each op it traverses, and follow the
definitions of those values.

However, existing callers might depend on captured values not being
traversed, so this change preserves the default behavior by placing the
new logic behind the `omitUsesFromAbove` flag.
---
 mlir/include/mlir/Analysis/SliceAnalysis.h |  5 +++++
 mlir/lib/Analysis/SliceAnalysis.cpp        | 14 ++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 99279fdfe427c8..a4f5d937cd51da 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -47,6 +47,11 @@ struct BackwardSliceOptions : public SliceOptions {
   /// backward slice computation traverses block arguments and asserts that the
   /// parent op has a single region with a single block.
   bool omitBlockArguments = false;
+
+  /// When omitUsesFromAbove is true, the backward slice computation omits
+  /// traversing values that are captured from above.
+  /// TODO: this should default to `false` after users have been updated.
+  bool omitUsesFromAbove = true;
 };
 
 using ForwardSliceOptions = SliceOptions;
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 2b1cf411ceeeeb..d07ae7b3ffa2c7 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -115,6 +116,19 @@ static void getBackwardSliceImpl(Operation *op,
     }
   }
 
+  // Visit values that are defined above.
+  if (!options.omitUsesFromAbove) {
+    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
+      if (Operation *definingOp = operand->get().getDefiningOp()) {
+        getBackwardSliceImpl(definingOp, backwardSlice, options);
+        return;
+      }
+      Operation *bbAargOwner =
+          cast<BlockArgument>(operand->get()).getOwner()->getParentOp();
+      getBackwardSliceImpl(bbAargOwner, backwardSlice, options);
+    });
+  }
+
   backwardSlice->insert(op);
 }
 

>From 3e8b47a6e9d7153baa70838d796f1593a12984d8 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Wed, 30 Oct 2024 03:51:23 +0000
Subject: [PATCH 2/2] Clean up implementation and add test

---
 mlir/lib/Analysis/SliceAnalysis.cpp | 21 +++++++--------------
 mlir/test/IR/slice.mlir             | 28 +++++++++++++++++++++++++++-
 mlir/test/lib/IR/TestSlicing.cpp    |  2 ++
 3 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index d07ae7b3ffa2c7..cd0dc25adf1ca4 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -92,7 +92,13 @@ static void getBackwardSliceImpl(Operation *op,
   if (options.filter && !options.filter(op))
     return;
 
-  for (const auto &en : llvm::enumerate(op->getOperands())) {
+  auto operands = op->getOperands();
+  SetVector<Value> valuesToFollow(operands.begin(), operands.end());
+  if (!options.omitUsesFromAbove) {
+    getUsedValuesDefinedAbove(op->getRegions(), valuesToFollow);
+  }
+
+  for (const auto &en : llvm::enumerate(valuesToFollow)) {
     auto operand = en.value();
     if (auto *definingOp = operand.getDefiningOp()) {
       if (backwardSlice->count(definingOp) == 0)
@@ -116,19 +122,6 @@ static void getBackwardSliceImpl(Operation *op,
     }
   }
 
-  // Visit values that are defined above.
-  if (!options.omitUsesFromAbove) {
-    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
-      if (Operation *definingOp = operand->get().getDefiningOp()) {
-        getBackwardSliceImpl(definingOp, backwardSlice, options);
-        return;
-      }
-      Operation *bbAargOwner =
-          cast<BlockArgument>(operand->get()).getOwner()->getParentOp();
-      getBackwardSliceImpl(bbAargOwner, backwardSlice, options);
-    });
-  }
-
   backwardSlice->insert(op);
 }
 
diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
index 0a32a0f231baf2..87d446c8f415af 100644
--- a/mlir/test/IR/slice.mlir
+++ b/mlir/test/IR/slice.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s
+// RUN: mlir-opt -slice-analysis-test -split-input-file %s | FileCheck %s
 
 func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
   %a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
@@ -33,3 +33,29 @@ func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
 //   CHECK-DAG:   %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
 //   CHECK-DAG:   %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
 //       CHECK:   return
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
+  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %c2 = arith.constant 2 : index
+    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
+    %2 = arith.addf %extracted, %extracted : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  return
+}
+
+// CHECK-LABEL: func @slice_use_from_above__backward_slice__0
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor 
+//       CHECK:   %[[A:.+]] = linalg.generic {{.*}} ins(%[[ARG0]]
+//       CHECK:   %[[B:.+]] = tensor.collapse_shape %[[A]]
+//       CHECK:   return
diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
index c3d0d151c6d24d..e99d5976d6d9df 100644
--- a/mlir/test/lib/IR/TestSlicing.cpp
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -39,6 +39,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
   SetVector<Operation *> slice;
   BackwardSliceOptions options;
   options.omitBlockArguments = omitBlockArguments;
+  // TODO: Make this default.
+  options.omitUsesFromAbove = false;
   getBackwardSlice(op, &slice, options);
   for (Operation *slicedOp : slice)
     builder.clone(*slicedOp, mapper);



More information about the Mlir-commits mailing list