[Mlir-commits] [mlir] 1bc58a2 - Extend `getBackwardSlice` to track values captured from above (#113478)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 31 07:47:51 PDT 2024


Author: Ian Wood
Date: 2024-10-31T07:47:48-07:00
New Revision: 1bc58a258e2edb6221009a26d0f0037eda6c7c47

URL: https://github.com/llvm/llvm-project/commit/1bc58a258e2edb6221009a26d0f0037eda6c7c47
DIFF: https://github.com/llvm/llvm-project/commit/1bc58a258e2edb6221009a26d0f0037eda6c7c47.diff

LOG: Extend `getBackwardSlice` to track values captured from above (#113478)

This change modifies `getBackwardSlice` to track values captured by the
regions of each operation that it traverses. Ignoring values captured
from a parent region may lead to an incomplete program slice. However,
there seems to be existing logic that depends on captured values not
being traversed, so this change preserves the default behavior by gating
the new traversal behind the `omitUsesFromAbove` flag, which defaults to
`true`.

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/SliceAnalysis.h
    mlir/lib/Analysis/SliceAnalysis.cpp
    mlir/test/IR/slice.mlir
    mlir/test/lib/IR/TestSlicing.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 99279fdfe427c8..a4f5d937cd51da 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -47,6 +47,11 @@ struct BackwardSliceOptions : public SliceOptions {
   /// backward slice computation traverses block arguments and asserts that the
   /// parent op has a single region with a single block.
   bool omitBlockArguments = false;
+
+  /// When omitUsesFromAbove is true, the backward slice computation omits
+  /// traversing values that are captured from above.
+  /// TODO: this should default to `false` after users have been updated.
+  bool omitUsesFromAbove = true;
 };
 
 using ForwardSliceOptions = SliceOptions;

diff  --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 2b1cf411ceeeeb..7ec999fa0370f9 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -16,6 +16,8 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -91,14 +93,13 @@ static void getBackwardSliceImpl(Operation *op,
   if (options.filter && !options.filter(op))
     return;
 
-  for (const auto &en : llvm::enumerate(op->getOperands())) {
-    auto operand = en.value();
-    if (auto *definingOp = operand.getDefiningOp()) {
+  auto processValue = [&](Value value) {
+    if (auto *definingOp = value.getDefiningOp()) {
       if (backwardSlice->count(definingOp) == 0)
         getBackwardSliceImpl(definingOp, backwardSlice, options);
-    } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+    } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
       if (options.omitBlockArguments)
-        continue;
+        return;
 
       Block *block = blockArg.getOwner();
       Operation *parentOp = block->getParentOp();
@@ -113,7 +114,14 @@ static void getBackwardSliceImpl(Operation *op,
     } else {
       llvm_unreachable("No definingOp and not a block argument.");
     }
+  };
+
+  if (!options.omitUsesFromAbove) {
+    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
+      processValue(operand->get());
+    });
   }
+  llvm::for_each(op->getOperands(), processValue);
 
   backwardSlice->insert(op);
 }

diff  --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
index 0a32a0f231baf2..87d446c8f415af 100644
--- a/mlir/test/IR/slice.mlir
+++ b/mlir/test/IR/slice.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s
+// RUN: mlir-opt -slice-analysis-test -split-input-file %s | FileCheck %s
 
 func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
   %a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
@@ -33,3 +33,29 @@ func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
 //   CHECK-DAG:   %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
 //   CHECK-DAG:   %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
 //       CHECK:   return
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
+  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %c2 = arith.constant 2 : index
+    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
+    %2 = arith.addf %extracted, %extracted : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  return
+}
+
+// CHECK-LABEL: func @slice_use_from_above__backward_slice__0
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor 
+//       CHECK:   %[[A:.+]] = linalg.generic {{.*}} ins(%[[ARG0]]
+//       CHECK:   %[[B:.+]] = tensor.collapse_shape %[[A]]
+//       CHECK:   return

diff  --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
index c3d0d151c6d24d..e99d5976d6d9df 100644
--- a/mlir/test/lib/IR/TestSlicing.cpp
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -39,6 +39,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
   SetVector<Operation *> slice;
   BackwardSliceOptions options;
   options.omitBlockArguments = omitBlockArguments;
+  // TODO: Make this default.
+  options.omitUsesFromAbove = false;
   getBackwardSlice(op, &slice, options);
   for (Operation *slicedOp : slice)
     builder.clone(*slicedOp, mapper);


        


More information about the Mlir-commits mailing list