[Mlir-commits] [mlir] 0a391c6 - [mlir][Analysis] Allow Slice Analysis to work with linalg::LinalgOp

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 10 18:54:57 PDT 2020


Author: MaheshRavishankar
Date: 2020-09-10T18:54:22-07:00
New Revision: 0a391c60793bae25804d2a82e5a26e2b9c7a69a1

URL: https://github.com/llvm/llvm-project/commit/0a391c60793bae25804d2a82e5a26e2b9c7a69a1
DIFF: https://github.com/llvm/llvm-project/commit/0a391c60793bae25804d2a82e5a26e2b9c7a69a1.diff

LOG: [mlir][Analysis] Allow Slice Analysis to work with linalg::LinalgOp

Differential Revision: https://reviews.llvm.org/D87307

Added: 
    mlir/test/IR/slice.mlir
    mlir/test/lib/IR/TestSlicing.cpp

Modified: 
    mlir/lib/Analysis/SliceAnalysis.cpp
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 8f5f87ba620e..120d4e4a9137 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
@@ -84,7 +85,8 @@ static void getBackwardSliceImpl(Operation *op,
   if (!op)
     return;
 
-  assert((op->getNumRegions() == 0 || isa<AffineForOp, scf::ForOp>(op)) &&
+  assert((op->getNumRegions() == 0 ||
+          isa<AffineForOp, scf::ForOp, linalg::LinalgOp>(op)) &&
          "unexpected generic op with regions");
 
   // Evaluate whether we should keep this def.

diff  --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
new file mode 100644
index 000000000000..731f3872f67d
--- /dev/null
+++ b/mlir/test/IR/slice.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s
+
+func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %a = alloc(%arg0, %arg2) : memref<?x?xf32>
+  %b = alloc(%arg2, %arg1) : memref<?x?xf32>
+  %c = alloc(%arg0, %arg1) : memref<?x?xf32>
+  %d = alloc(%arg0, %arg1) : memref<?x?xf32>
+  linalg.matmul %a, %b, %c : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  linalg.matmul %a, %b, %d : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+  dealloc %c : memref<?x?xf32>
+  dealloc %b : memref<?x?xf32>
+  dealloc %a : memref<?x?xf32>
+  dealloc %d : memref<?x?xf32>
+  return
+}
+// Expect one function per linalg op holding clones of its backward slice.
+// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+//   CHECK-DAG:   %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
+//   CHECK-DAG:   %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
+//   CHECK-DAG:   %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
+//       CHECK:   return
+// Note: %c and %d share alloc operands, so the D match cannot exclude %c.
+// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+//   CHECK-DAG:   %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
+//   CHECK-DAG:   %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
+//   CHECK-DAG:   %[[D:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
+//       CHECK:   return

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index cf4ecada0f3c..a42f90bb9268 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTestIR
   TestPrintDefUse.cpp
   TestPrintNesting.cpp
   TestSideEffects.cpp
+  TestSlicing.cpp
   TestSymbolUses.cpp
   TestTypes.cpp
 

diff  --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
new file mode 100644
index 000000000000..a95b2f84cfcf
--- /dev/null
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -0,0 +1,81 @@
+//===- TestSlicing.cpp - Testing slice functionality ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a simple testing pass for slicing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+
+/// Insert before `op`'s enclosing function a same-signature function named
+/// `<enclosing name><suffix>` whose body clones `op`'s backward slice.
+static LogicalResult createBackwardSliceFunction(Operation *op,
+                                                 StringRef suffix) {
+  FuncOp enclosingFunc = op->getParentOfType<FuncOp>();
+  std::string sliceFuncName = enclosingFunc.getName().str() + suffix.str();
+  Location opLoc = op->getLoc();
+  OpBuilder b(enclosingFunc);
+  FuncOp sliceFunc =
+      b.create<FuncOp>(opLoc, sliceFuncName, enclosingFunc.getType());
+  b.setInsertionPointToEnd(sliceFunc.addEntryBlock());
+  BlockAndValueMapping argMap;
+  for (unsigned i = 0, e = enclosingFunc.getNumArguments(); i != e; ++i)
+    argMap.map(enclosingFunc.getArgument(i), sliceFunc.getArgument(i));
+  llvm::SetVector<Operation *> sliceOps;
+  getBackwardSlice(op, &sliceOps);
+  for (Operation *sliceOp : sliceOps)
+    b.clone(*sliceOp, argMap);
+  b.create<ReturnOp>(opLoc);
+  return success();
+}
+
+namespace {
+/// Module pass that emits a backward-slice function for every Linalg op.
+struct SliceAnalysisTestPass
+    : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
+  SliceAnalysisTestPass() = default;
+  SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
+  void runOnOperation() override;
+};
+} // namespace
+
+void SliceAnalysisTestPass::runOnOperation() {
+  // A single counter across all functions keeps every generated name unique.
+  unsigned sliceIndex = 0;
+  // Slice functions are created just before their parent function, so this
+  // forward iteration does not visit them.
+  for (FuncOp func : getOperation().getOps<FuncOp>()) {
+    // TODO: For now this is just looking for Linalg ops. It can be generalized
+    // to look for other ops using flags.
+    func.walk([&](Operation *candidate) {
+      if (isa<linalg::LinalgOp>(candidate)) {
+        std::string suffix = "__backward_slice__" + std::to_string(sliceIndex);
+        createBackwardSliceFunction(candidate, suffix);
+        ++sliceIndex;
+      }
+      return WalkResult::advance();
+    });
+  }
+}
+
+namespace mlir {
+void registerSliceAnalysisTestPass() { // Exposed as -slice-analysis-test.
+  PassRegistration<SliceAnalysisTestPass> registration(
+      "slice-analysis-test", "Test Slice analysis functionality.");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 437b5f4b6f1a..e46327aa6399 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -38,6 +38,7 @@ void registerPatternsTestPass();
 void registerPrintOpAvailabilityPass();
 void registerSideEffectTestPasses();
 void registerSimpleParametricTilingPass();
+void registerSliceAnalysisTestPass();
 void registerSymbolTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAffineLoopUnswitchingPass();
@@ -88,6 +89,7 @@ void registerTestPasses() {
   registerPrintOpAvailabilityPass();
   registerSideEffectTestPasses();
   registerSimpleParametricTilingPass();
+  registerSliceAnalysisTestPass();
   registerSymbolTestPasses();
   registerTestAffineDataCopyPass();
   registerTestAllReduceLoweringPass();


        


More information about the Mlir-commits mailing list