[Mlir-commits] [mlir] [mlir] Use transform dialect for backwardslice tests (PR #159634)
Ian Wood
llvmlistbot at llvm.org
Thu Sep 18 12:37:15 PDT 2025
https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/159634
The current `SliceAnalysisTestPass` uses every linalg op it finds as the starting point of a backward slice. Using transform dialect ops instead allows specifying which op to start the backward slice from. This will be useful for https://github.com/llvm/llvm-project/pull/158135, where I want to test an operation that is isolated from above. It also means that this test no longer needs to depend on a specific dialect (in this case, linalg).
The existing tests were moved into `mlir/test/Analysis/test-backwardslice.mlir` and rewritten to use unregistered ops instead of linalg/memref ops.
>From f08d4b4766b39dce131d8df1b4870a287f0ce36c Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood at u.northwestern.edu>
Date: Thu, 18 Sep 2025 12:26:57 -0700
Subject: [PATCH] [mlir] Use transform dialect for backwardslice tests
Signed-off-by: Ian Wood <ianwood at u.northwestern.edu>
---
mlir/test/Analysis/test-backwardslice.mlir | 97 +++++++++++++++++++++
mlir/test/IR/slice.mlir | 61 -------------
mlir/test/IR/slice_multiple_blocks.mlir | 36 --------
mlir/test/lib/Analysis/CMakeLists.txt | 9 ++
mlir/test/lib/Analysis/TestAnalysisOps.cpp | 98 +++++++++++++++++++++
mlir/test/lib/Analysis/TestAnalysisOps.td | 39 +++++++++
mlir/test/lib/Analysis/lit.local.cfg | 1 +
mlir/test/lib/IR/CMakeLists.txt | 1 -
mlir/test/lib/IR/TestSlicing.cpp | 99 ----------------------
mlir/tools/mlir-opt/mlir-opt.cpp | 4 +-
10 files changed, 246 insertions(+), 199 deletions(-)
create mode 100644 mlir/test/Analysis/test-backwardslice.mlir
delete mode 100644 mlir/test/IR/slice.mlir
delete mode 100644 mlir/test/IR/slice_multiple_blocks.mlir
create mode 100644 mlir/test/lib/Analysis/TestAnalysisOps.cpp
create mode 100644 mlir/test/lib/Analysis/TestAnalysisOps.td
create mode 100644 mlir/test/lib/Analysis/lit.local.cfg
delete mode 100644 mlir/test/lib/IR/TestSlicing.cpp
diff --git a/mlir/test/Analysis/test-backwardslice.mlir b/mlir/test/Analysis/test-backwardslice.mlir
new file mode 100644
index 0000000000000..5a72aacb94e06
--- /dev/null
+++ b/mlir/test/Analysis/test-backwardslice.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+// The root's only operand is produced by "other" in the same block, so the
+// cloned function holds "other", then "root", then the appended return.
+func.func @simple() {
+ %0 = "other"() : () -> (f32)
+ %1 = "root"(%0) : (f32) -> (f32)
+}
+// CHECK-LABEL: func @simple__backward_slice__()
+// CHECK: %[[OTHER:.+]] = "other"
+// CHECK: %[[ROOT:.+]] = "root"(%[[OTHER]])
+// CHECK: return
+
+// Match the payload op named "root" and clone its backward slice.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op = transform.structured.match ops{["root"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ transform.test.get_backward_slice %op : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The root's operand is defined in the predecessor block (before the cf.br),
+// so the slice has to reach across the block boundary to pick it up.
+func.func @across_blocks() {
+ %0 = "other"() : () -> (f32)
+ cf.br ^bb1
+^bb1() :
+ %1 = "root"(%0) : (f32) -> (f32)
+}
+// CHECK-LABEL: func @across_blocks__backward_slice__()
+// CHECK: %[[OTHER:.+]] = "other"
+// CHECK: %[[ROOT:.+]] = "root"(%[[OTHER]])
+// CHECK: return
+
+// Match the payload op named "root" and clone its backward slice.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op = transform.structured.match ops{["root"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ transform.test.get_backward_slice %op : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Only the transitive producers of "root" belong in the slice; the unrelated
+// "not_in_slice" ops placed before, between and after them must not be cloned.
+func.func @large_slice() {
+ %0 = "not_in_slice"() : () -> (f32)
+ %1 = "sliced_op0"() : () -> (f32)
+ %2 = "sliced_op1"() : () -> (f32)
+ %3 = "sliced_op"(%1, %2) : (f32, f32) -> (f32)
+ %4 = "not_in_slice"() : () -> (f32)
+ %5 = "root"(%3) : (f32) -> (f32)
+ %6 = "not_in_slice"() : () -> (f32)
+}
+// CHECK-LABEL: func @large_slice__backward_slice__()
+// CHECK-NOT: "not_in_slice"
+// CHECK-DAG: %[[OP0:.+]] = "sliced_op0"
+// CHECK-DAG: %[[OP1:.+]] = "sliced_op1"
+// CHECK-NOT: "not_in_slice"
+// CHECK: %[[OP2:.+]] = "sliced_op"(%[[OP0]], %[[OP1]])
+// CHECK: %[[ROOT:.+]] = "root"(%[[OP2]])
+// CHECK-NOT: "not_in_slice"
+// CHECK: return
+
+// Match the payload op named "root" and clone its backward slice.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op = transform.structured.match ops{["root"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ transform.test.get_backward_slice %op : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The region of the root's producer uses %0 from above. The test op runs the
+// slice with `omitUsesFromAbove = false`, so %0's producer must be in the
+// slice and cloned before the region-holding op that captures it.
+func.func @include_uses_from_above() {
+ %0 = "sliced_op"() : () -> (f32)
+ %1 = "sliced_op" () ({
+ ^bb0():
+ "yield" (%0) : (f32) -> ()
+ }): () -> (f32)
+ %2 = "root"(%1) : (f32) -> (f32)
+}
+// CHECK-LABEL: func @include_uses_from_above__backward_slice__()
+// CHECK: %[[OP0:.+]] = "sliced_op"
+// CHECK: %[[OP1:.+]] = "sliced_op"
+// CHECK-NEXT: "yield"(%[[OP0]])
+// CHECK: %[[ROOT:.+]] = "root"(%[[OP1]])
+// CHECK: return
+
+// Match the payload op named "root" and clone its backward slice.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op = transform.structured.match ops{["root"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ transform.test.get_backward_slice %op : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
deleted file mode 100644
index 87d446c8f415a..0000000000000
--- a/mlir/test/IR/slice.mlir
+++ /dev/null
@@ -1,61 +0,0 @@
-// RUN: mlir-opt -slice-analysis-test -split-input-file %s | FileCheck %s
-
-func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
- %a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
- %b = memref.alloc(%arg2, %arg1) : memref<?x?xf32>
- %c = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
- %d = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
- linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
- outs(%c : memref<?x?xf32>)
- linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
- outs(%d : memref<?x?xf32>)
- memref.dealloc %c : memref<?x?xf32>
- memref.dealloc %b : memref<?x?xf32>
- memref.dealloc %a : memref<?x?xf32>
- memref.dealloc %d : memref<?x?xf32>
- return
-}
-
-// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
-// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK: return
-
-// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
-// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK: return
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
- %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%in: f32, %out: f32):
- %2 = arith.addf %in, %in : f32
- linalg.yield %2 : f32
- } -> tensor<5x5xf32>
- %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%in: f32, %out: f32):
- %c2 = arith.constant 2 : index
- %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
- %2 = arith.addf %extracted, %extracted : f32
- linalg.yield %2 : f32
- } -> tensor<5x5xf32>
- return
-}
-
-// CHECK-LABEL: func @slice_use_from_above__backward_slice__0
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
-// CHECK: %[[A:.+]] = linalg.generic {{.*}} ins(%[[ARG0]]
-// CHECK: %[[B:.+]] = tensor.collapse_shape %[[A]]
-// CHECK: return
diff --git a/mlir/test/IR/slice_multiple_blocks.mlir b/mlir/test/IR/slice_multiple_blocks.mlir
deleted file mode 100644
index 395a4e970d5d4..0000000000000
--- a/mlir/test/IR/slice_multiple_blocks.mlir
+++ /dev/null
@@ -1,36 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(slice-analysis-test{omit-block-arguments=true})" %s | FileCheck %s
-
-func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
- %a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
- %b = memref.alloc(%arg2, %arg1) : memref<?x?xf32>
- cf.br ^bb1
-^bb1() :
- %c = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
- %d = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
- linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
- outs(%c : memref<?x?xf32>)
- linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
- outs(%d : memref<?x?xf32>)
- memref.dealloc %c : memref<?x?xf32>
- memref.dealloc %b : memref<?x?xf32>
- memref.dealloc %a : memref<?x?xf32>
- memref.dealloc %d : memref<?x?xf32>
- return
-}
-// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
-// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK: return
-
-// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
-// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
-// CHECK: return
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 91879981bffd2..840f08385ccc5 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -1,3 +1,8 @@
+# Generate the C++ op declarations (.h.inc) and definitions (.cpp.inc) for the
+# transform ops declared in TestAnalysisOps.td.
+set(LLVM_TARGET_DEFINITIONS TestAnalysisOps.td)
+mlir_tablegen(TestAnalysisOps.h.inc -gen-op-decls)
+mlir_tablegen(TestAnalysisOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTestAnalysisOpsIncGen)
+
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestAnalysis
TestAliasAnalysis.cpp
@@ -11,6 +16,7 @@ add_mlir_library(MLIRTestAnalysis
TestMemRefStrideCalculation.cpp
TestSlice.cpp
TestTopologicalSort.cpp
+ TestAnalysisOps.cpp
DataFlow/TestDeadCodeAnalysis.cpp
DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -20,6 +26,9 @@ add_mlir_library(MLIRTestAnalysis
EXCLUDE_FROM_LIBMLIR
+ DEPENDS
+ MLIRTestAnalysisOpsIncGen
+
LINK_LIBS PUBLIC
MLIRTestDialect
)
diff --git a/mlir/test/lib/Analysis/TestAnalysisOps.cpp b/mlir/test/lib/Analysis/TestAnalysisOps.cpp
new file mode 100644
index 0000000000000..22435f3c22079
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestAnalysisOps.cpp
@@ -0,0 +1,98 @@
+//===- TestAnalysisOps.cpp - Test Analysis Transform Ops ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines transform dialect operations for testing MLIR
+// analyses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "TestAnalysisOps.h.inc"
+
+using namespace mlir;
+using namespace transform;
+
+#define GET_OP_CLASSES
+#include "TestAnalysisOps.cpp.inc"
+
+/// Create a function with the same signature as the parent `func.func` of
+/// `op`, named after that function with `suffix` appended, and insert it right
+/// before the parent function. Its body is a clone of the backward slice of
+/// `op` (computed with `options`), terminated by a `func.return`.
+///
+/// Returns failure if `op` has no parent `func.func` or if the backward slice
+/// cannot be computed. The slice is computed before any IR is created, so on
+/// failure the IR is left unmodified.
+static LogicalResult
+createBackwardSliceFunction(Operation *op, StringRef suffix,
+                            const BackwardSliceOptions &options) {
+  func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>();
+  if (!parentFuncOp)
+    return failure();
+
+  // Compute the slice first: previously a failure was only asserted, which
+  // vanishes in release builds and left a partially populated function behind.
+  SetVector<Operation *> slice;
+  if (failed(getBackwardSlice(op, &slice, options)))
+    return failure();
+
+  // The builder's insertion point is right before the parent function.
+  OpBuilder builder(parentFuncOp);
+  Location loc = op->getLoc();
+  std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
+  func::FuncOp clonedFuncOp = func::FuncOp::create(
+      builder, loc, clonedFuncOpName, parentFuncOp.getFunctionType());
+  builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
+
+  // Rewire uses of the parent's arguments to the clone's arguments.
+  IRMapping mapper;
+  mapper.map(parentFuncOp.getArguments(), clonedFuncOp.getArguments());
+
+  // getBackwardSlice documents its result as topologically ordered, so every
+  // in-slice operand is already mapped by the time its user is cloned.
+  for (Operation *slicedOp : slice)
+    builder.clone(*slicedOp, mapper);
+  func::ReturnOp::create(builder, loc);
+  return success();
+}
+
+/// Clone the backward slice of the payload op associated with the `op` handle
+/// into a new function named `<parent function name>__backward_slice__`.
+///
+/// The slice omits block arguments, follows uses from above (values captured
+/// by nested regions), and includes the payload op itself.
+DiagnosedSilenceableFailure
+transform::TestGetBackwardSlice::apply(TransformRewriter &rewriter,
+                                       TransformResults &transformResults,
+                                       TransformState &state) {
+  // Dereferencing begin() of an empty payload range is undefined behavior, and
+  // silently ignoring extra payload ops would hide mistakes in a test, so
+  // require exactly one payload op.
+  auto payloadOps = state.getPayloadOps(getOp());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitSilenceableError() << "expected a single payload op";
+  Operation *op = *payloadOps.begin();
+
+  StringRef suffix = "__backward_slice__";
+  BackwardSliceOptions options;
+  options.omitBlockArguments = true;
+  // TODO: Make this default.
+  options.omitUsesFromAbove = false;
+  options.inclusive = true;
+  // Report why the transform failed instead of failing without a diagnostic.
+  if (failed(createBackwardSliceFunction(op, suffix, options)))
+    return emitDefiniteFailure()
+           << "failed to clone the backward slice of the payload op into a "
+              "new function (the payload op must be nested in a 'func.func')";
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Extension
+//===----------------------------------------------------------------------===//
+namespace {
+/// Transform dialect extension that registers every op generated from
+/// TestAnalysisOps.td.
+struct TestAnalysisDialectExtension
+    : transform::TransformDialectExtension<TestAnalysisDialectExtension> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAnalysisDialectExtension)
+  using Base::Base;
+
+  /// Register all ops listed by the generated GET_OP_LIST.
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "TestAnalysisOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+namespace test {
+/// Register the transform dialect extension that provides the analysis test
+/// ops (e.g. `transform.test.get_backward_slice`) with `registry`.
+void registerTestAnalysisTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<TestAnalysisDialectExtension>();
+}
+} // namespace test
diff --git a/mlir/test/lib/Analysis/TestAnalysisOps.td b/mlir/test/lib/Analysis/TestAnalysisOps.td
new file mode 100644
index 0000000000000..d8aa66a69db9e
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestAnalysisOps.td
@@ -0,0 +1,39 @@
+//===- TestAnalysisOps.td ---------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_ANALYSIS_OPS
+#define TEST_ANALYSIS_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+/// Transform dialect operations for testing analysis in MLIR
+
+def TestGetBackwardSlice :
+    Op<Transform_Dialect, "test.get_backward_slice",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Clone the backward slice of a payload op into a new function";
+  let description = [{
+    Test `getBackwardSlice` by cloning the slice starting at the payload op
+    associated with the `op` handle.
+
+    A new `func.func` with the same signature as the payload op's enclosing
+    function is inserted right before that function. It is named after the
+    enclosing function with a `__backward_slice__` suffix, and its body holds
+    clones of the sliced ops, in slice order, followed by `func.return`.
+
+    The slice is computed with block arguments omitted, uses from above
+    (values captured by nested regions) followed, and the payload op itself
+    included. This op has no results.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$op);
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $op attr-dict `:` type($op)
+  }];
+}
+
+#endif // TEST_ANALYSIS_OPS
diff --git a/mlir/test/lib/Analysis/lit.local.cfg b/mlir/test/lib/Analysis/lit.local.cfg
new file mode 100644
index 0000000000000..65a7f202dc82a
--- /dev/null
+++ b/mlir/test/lib/Analysis/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes.remove(".td")
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 1abcfc77d2d9b..2ab2e62885d40 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -18,7 +18,6 @@ add_mlir_library(MLIRTestIR
TestPrintInvalid.cpp
TestPrintNesting.cpp
TestSideEffects.cpp
- TestSlicing.cpp
TestSymbolUses.cpp
TestRegions.cpp
TestTypes.cpp
diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
deleted file mode 100644
index 5a5ac450f91fb..0000000000000
--- a/mlir/test/lib/IR/TestSlicing.cpp
+++ /dev/null
@@ -1,99 +0,0 @@
-//===- TestSlicing.cpp - Testing slice functionality ----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a simple testing pass for slicing.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-
-using namespace mlir;
-
-/// Create a function with the same signature as the parent function of `op`
-/// with name being the function name and a `suffix`.
-static LogicalResult createBackwardSliceFunction(Operation *op,
- StringRef suffix,
- bool omitBlockArguments) {
- func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>();
- OpBuilder builder(parentFuncOp);
- Location loc = op->getLoc();
- std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
- func::FuncOp clonedFuncOp = func::FuncOp::create(
- builder, loc, clonedFuncOpName, parentFuncOp.getFunctionType());
- IRMapping mapper;
- builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
- for (const auto &arg : enumerate(parentFuncOp.getArguments()))
- mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
- SetVector<Operation *> slice;
- BackwardSliceOptions options;
- options.omitBlockArguments = omitBlockArguments;
- // TODO: Make this default.
- options.omitUsesFromAbove = false;
- LogicalResult result = getBackwardSlice(op, &slice, options);
- assert(result.succeeded() && "expected a backward slice");
- (void)result;
- for (Operation *slicedOp : slice)
- builder.clone(*slicedOp, mapper);
- func::ReturnOp::create(builder, loc);
- return success();
-}
-
-namespace {
-/// Pass to test slice generated from slice analysis.
-struct SliceAnalysisTestPass
- : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SliceAnalysisTestPass)
-
- StringRef getArgument() const final { return "slice-analysis-test"; }
- StringRef getDescription() const final {
- return "Test Slice analysis functionality.";
- }
-
- Option<bool> omitBlockArguments{
- *this, "omit-block-arguments",
- llvm::cl::desc("Test Slice analysis with multiple blocks but slice "
- "omiting block arguments"),
- llvm::cl::init(true)};
-
- void runOnOperation() override;
- SliceAnalysisTestPass() = default;
- SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
-};
-} // namespace
-
-void SliceAnalysisTestPass::runOnOperation() {
- ModuleOp module = getOperation();
- auto funcOps = module.getOps<func::FuncOp>();
- unsigned opNum = 0;
- for (auto funcOp : funcOps) {
- // TODO: For now this is just looking for Linalg ops. It can be generalized
- // to look for other ops using flags.
- funcOp.walk([&](Operation *op) {
- if (!isa<linalg::LinalgOp>(op))
- return WalkResult::advance();
- std::string append =
- std::string("__backward_slice__") + std::to_string(opNum);
- (void)createBackwardSliceFunction(op, append, omitBlockArguments);
- opNum++;
- return WalkResult::advance();
- });
- }
-}
-
-namespace mlir {
-void registerSliceAnalysisTestPass() {
- PassRegistration<SliceAnalysisTestPass>();
-}
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e4620c009af8c..a66bae176df64 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -42,7 +42,6 @@ void registerRegionTestPasses();
void registerPrintTosaAvailabilityPass();
void registerShapeFunctionTestPasses();
void registerSideEffectTestPasses();
-void registerSliceAnalysisTestPass();
void registerSymbolTestPasses();
void registerTestAffineAccessAnalysisPass();
void registerTestAffineDataCopyPass();
@@ -176,6 +175,7 @@ void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &);
void registerIrdlTestDialect(DialectRegistry &);
void registerTestTransformsTransformDialectExtension(DialectRegistry &);
+void registerTestAnalysisTransformDialectExtension(DialectRegistry &);
} // namespace test
#ifdef MLIR_INCLUDE_TESTS
@@ -190,7 +190,6 @@ void registerTestPasses() {
registerRegionTestPasses();
registerShapeFunctionTestPasses();
registerSideEffectTestPasses();
- registerSliceAnalysisTestPass();
registerSymbolTestPasses();
registerTestAffineAccessAnalysisPass();
registerTestAffineDataCopyPass();
@@ -337,6 +336,7 @@ int main(int argc, char **argv) {
::test::registerTestTilingInterfaceTransformDialectExtension(registry);
::test::registerTestTransformDialectExtension(registry);
::test::registerTestTransformsTransformDialectExtension(registry);
+ ::test::registerTestAnalysisTransformDialectExtension(registry);
#endif
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "MLIR modular optimizer driver\n", registry));
More information about the Mlir-commits
mailing list