[Mlir-commits] [mlir] 0a391c6 - [mlir][Analysis] Allow Slice Analysis to work with linalg::LinalgOp
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 10 18:54:57 PDT 2020
Author: MaheshRavishankar
Date: 2020-09-10T18:54:22-07:00
New Revision: 0a391c60793bae25804d2a82e5a26e2b9c7a69a1
URL: https://github.com/llvm/llvm-project/commit/0a391c60793bae25804d2a82e5a26e2b9c7a69a1
DIFF: https://github.com/llvm/llvm-project/commit/0a391c60793bae25804d2a82e5a26e2b9c7a69a1.diff
LOG: [mlir][Analysis] Allow Slice Analysis to work with linalg::LinalgOp
Differential Revision: https://reviews.llvm.org/D87307
Added:
mlir/test/IR/slice.mlir
mlir/test/lib/IR/TestSlicing.cpp
Modified:
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/test/lib/IR/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 8f5f87ba620e..120d4e4a9137 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -12,6 +12,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
@@ -84,7 +85,8 @@ static void getBackwardSliceImpl(Operation *op,
if (!op)
return;
- assert((op->getNumRegions() == 0 || isa<AffineForOp, scf::ForOp>(op)) &&
+ assert((op->getNumRegions() == 0 ||
+ isa<AffineForOp, scf::ForOp, linalg::LinalgOp>(op)) &&
"unexpected generic op with regions");
// Evaluate whether we should keep this def.
diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
new file mode 100644
index 000000000000..731f3872f67d
--- /dev/null
+++ b/mlir/test/IR/slice.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s
+
+// Two linalg.matmul ops read the same %a and %b but write to different
+// outputs (%c and %d). The test pass emits one new function per Linalg op,
+// named "<parent>__backward_slice__<N>", whose body is that op's backward
+// slice followed by a return.
+func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
+ %a = alloc(%arg0, %arg2) : memref<?x?xf32>
+ %b = alloc(%arg2, %arg1) : memref<?x?xf32>
+ %c = alloc(%arg0, %arg1) : memref<?x?xf32>
+ %d = alloc(%arg0, %arg1) : memref<?x?xf32>
+ linalg.matmul %a, %b, %c : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.matmul %a, %b, %d : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ dealloc %c : memref<?x?xf32>
+ dealloc %b : memref<?x?xf32>
+ dealloc %a : memref<?x?xf32>
+ dealloc %d : memref<?x?xf32>
+ return
+}
+
+// Slice of the first matmul: the allocs of %a, %b and %c.
+// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
+// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK: return
+
+// Slice of the second matmul: the allocs of %a, %b and %d. The capture named
+// C below binds to the alloc of %d, which has the same operands and type as
+// the alloc of %c.
+// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
+// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK: return
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index cf4ecada0f3c..a42f90bb9268 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTestIR
TestPrintDefUse.cpp
TestPrintNesting.cpp
TestSideEffects.cpp
+ TestSlicing.cpp
TestSymbolUses.cpp
TestTypes.cpp
diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
new file mode 100644
index 000000000000..a95b2f84cfcf
--- /dev/null
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -0,0 +1,81 @@
+//===- TestSlicing.cpp - Testing slice functionality ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a simple testing pass for slicing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+
+/// Create a function with the same signature as the parent function of `op`
+/// with name being the function name and a `suffix`. The new function is
+/// inserted immediately before the parent function. Its body is a clone of
+/// the backward slice of `op` (as returned by `getBackwardSlice`, cloned in
+/// that order) followed by an empty `return`.
+///
+/// Returns failure, after emitting an error on `op`, if `op` is not nested
+/// inside a FuncOp; nothing is created in that case.
+///
+/// NOTE(review): only the parent function's arguments are pre-mapped. An
+/// operand of a sliced op that is neither a function argument nor produced by
+/// an earlier cloned op would keep referring to the original function's
+/// value; confirm callers only use this where that cannot happen.
+static LogicalResult createBackwardSliceFunction(Operation *op,
+                                                 StringRef suffix) {
+  FuncOp parentFuncOp = op->getParentOfType<FuncOp>();
+  // Without an enclosing function there is no signature to copy and no
+  // insertion point; report it instead of dereferencing a null FuncOp.
+  if (!parentFuncOp)
+    return op->emitError("expected op to be nested inside a function");
+
+  // A builder constructed from an op inserts right before that op, so the
+  // new function is placed just ahead of its parent in the module.
+  OpBuilder builder(parentFuncOp);
+  Location loc = op->getLoc();
+  std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
+  FuncOp clonedFuncOp =
+      builder.create<FuncOp>(loc, clonedFuncOpName, parentFuncOp.getType());
+
+  // Rewire uses of the parent's arguments to the new function's arguments.
+  BlockAndValueMapping mapper;
+  builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
+  for (auto arg : enumerate(parentFuncOp.getArguments()))
+    mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
+
+  // `clone` records the results of every cloned op in `mapper`, so later
+  // slice ops pick up the clones of their producers.
+  llvm::SetVector<Operation *> slice;
+  getBackwardSlice(op, &slice);
+  for (Operation *slicedOp : slice)
+    builder.clone(*slicedOp, mapper);
+  builder.create<ReturnOp>(loc);
+  return success();
+}
+
+namespace {
+/// Module pass used to exercise slice analysis: every Linalg op found in the
+/// module's functions gets its backward slice materialized as a new function.
+struct SliceAnalysisTestPass
+    : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
+  SliceAnalysisTestPass() = default;
+  // The pass holds no state of its own, so a copy is just a freshly
+  // default-constructed pass.
+  SliceAnalysisTestPass(const SliceAnalysisTestPass &)
+      : SliceAnalysisTestPass() {}
+
+  void runOnOperation() override;
+};
+} // end anonymous namespace
+
+/// For each Linalg op in the module (functions in module order, ops in walk
+/// order), create a function "<parent>__backward_slice__<N>" containing the
+/// op's backward slice. N counts Linalg ops across the whole module, so the
+/// generated names are unique per run. Signals pass failure if a slice
+/// function cannot be created.
+void SliceAnalysisTestPass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  // Gather the ops first: createBackwardSliceFunction inserts new functions
+  // into `module`, so avoid mutating the module while traversing it.
+  SmallVector<Operation *, 4> linalgOps;
+  for (FuncOp funcOp : module.getOps<FuncOp>()) {
+    // TODO: For now this is just looking for Linalg ops. It can be generalized
+    // to look for other ops using flags.
+    funcOp.walk([&](Operation *op) {
+      if (isa<linalg::LinalgOp>(op))
+        linalgOps.push_back(op);
+    });
+  }
+
+  for (auto en : llvm::enumerate(linalgOps)) {
+    std::string suffix = "__backward_slice__" + std::to_string(en.index());
+    if (failed(createBackwardSliceFunction(en.value(), suffix)))
+      return signalPassFailure();
+  }
+}
+
+namespace mlir {
+/// Makes the `-slice-analysis-test` pass available by name in MLIR's pass
+/// registry; invoked from mlir-opt's registerTestPasses().
+void registerSliceAnalysisTestPass() {
+  PassRegistration<SliceAnalysisTestPass> registration(
+      "slice-analysis-test", "Test Slice analysis functionality.");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 437b5f4b6f1a..e46327aa6399 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -38,6 +38,7 @@ void registerPatternsTestPass();
void registerPrintOpAvailabilityPass();
void registerSideEffectTestPasses();
void registerSimpleParametricTilingPass();
+void registerSliceAnalysisTestPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAffineLoopUnswitchingPass();
@@ -88,6 +89,7 @@ void registerTestPasses() {
registerPrintOpAvailabilityPass();
registerSideEffectTestPasses();
registerSimpleParametricTilingPass();
+ registerSliceAnalysisTestPass();
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();
More information about the Mlir-commits mailing list