[Mlir-commits] [mlir] 9f77909 - [mlir][shape] add outline-shape-computation pass
Jacques Pienaar
llvmlistbot at llvm.org
Sun Oct 2 20:24:59 PDT 2022
Author: Yuanqiang Liu
Date: 2022-10-02T20:24:49-07:00
New Revision: 9f77909a5e07b7973fe13d4ea6391c29ff1b46b5
URL: https://github.com/llvm/llvm-project/commit/9f77909a5e07b7973fe13d4ea6391c29ff1b46b5
DIFF: https://github.com/llvm/llvm-project/commit/9f77909a5e07b7973fe13d4ea6391c29ff1b46b5.diff
LOG: [mlir][shape] add outline-shape-computation pass
Add outline-shape-computation pass. This pass outlines the shape
computation part in high-level IR by adding shape.func ops and populates
the corresponding mapping information into ShapeMappingAnalysis.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D131810
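For convenience, the pass can be exercised from mlir-opt together with the new
test printer pass (this mirrors the RUN line of the added test; input.mlir here
is a placeholder for any module containing reified shape computations):

  mlir-opt -outline-shape-computation -test-print-shape-mapping input.mlir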
Added:
mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
mlir/test/Dialect/Shape/outline-shape-computation.mlir
mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp
Modified:
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Shape/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h b/mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
new file mode 100644
index 0000000000000..25befa3590856
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
@@ -0,0 +1,60 @@
+//===- ShapeMappingAnalysis.h - Preserve shape Info ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+#define MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+namespace shape {
+
+/// ShapeMappingValue is the value type of the ShapeMappingAnalysis table:
+/// `funcSymbol` is the symbol of the mapping function, and `inputs` are the
+/// actual arguments to that function.
+struct ShapeMappingValue {
+ ShapeMappingValue() = default;
+ ShapeMappingValue(FlatSymbolRefAttr symbol, llvm::SmallVector<Value> &&inps)
+ : funcSymbol(symbol), inputs(inps) {}
+
+ FlatSymbolRefAttr funcSymbol;
+ llvm::SmallVector<Value> inputs;
+};
+
+/// ShapeMappingAnalysis is used together with OutlineShapeComputationPass to
+/// preserve the mapping from each Value to its corresponding shape function
+/// and arguments.
+struct ShapeMappingAnalysis {
+ ShapeMappingAnalysis(Operation *op) : operation(op) { (void)operation; }
+
+ /// Dumps the shape mapping information to the given stream.
+ void print(raw_ostream &os) const {
+ os << "// ---- Shape Mapping Information -----\n";
+ for (const auto &it : shapeMapping) {
+ const ShapeMappingValue &mappingValue = it.second;
+ os << "// Shape for " << it.first << " :: " << mappingValue.funcSymbol;
+ llvm::interleaveComma(mappingValue.inputs, os << "(");
+ os << ")\n";
+ }
+ }
+
+ llvm::DenseMap<Value, ShapeMappingValue> shapeMapping;
+
+private:
+ Operation *operation;
+};
+
+} // namespace shape
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
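For reference, here is a minimal sketch (not part of this commit) of how a
downstream module pass could consume the populated analysis; the pass name is
hypothetical and the body only prints the recorded mapping:

  #include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/Pass.h"

  namespace {
  // Hypothetical consumer of the Value -> (shape.func symbol, arguments) table
  // left behind by outline-shape-computation.
  struct ConsumeShapeMappingPass
      : mlir::PassWrapper<ConsumeShapeMappingPass,
                          mlir::OperationPass<mlir::ModuleOp>> {
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConsumeShapeMappingPass)

    void runOnOperation() override {
      // Only a cached analysis can exist here, because the table is filled in
      // by the outlining pass rather than computed on demand.
      auto maybeAnalysis =
          getCachedAnalysis<mlir::shape::ShapeMappingAnalysis>();
      if (!maybeAnalysis)
        return;
      for (auto &entry : maybeAnalysis->get().shapeMapping) {
        // entry.first is the dynamically shaped Value; entry.second carries
        // the shape.func symbol and the actual arguments to evaluate it with.
        llvm::errs() << "shape of " << entry.first << " is computed by "
                     << entry.second.funcSymbol << " with "
                     << entry.second.inputs.size() << " argument(s)\n";
      }
    }
  };
  } // namespace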
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index be3c74123d065..cfb637f133f54 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,7 @@
namespace mlir {
class ConversionTarget;
+class ModuleOp;
class TypeConverter;
namespace func {
class FuncOp;
@@ -53,6 +54,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass();
// level.
std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
+/// Outline the shape computation part by adding shape.func ops and populate
+/// the corresponding mapping information into ShapeMappingAnalysis.
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 503780f0d1c7c..9dfda9ea33615 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,6 +11,88 @@
include "mlir/Pass/PassBase.td"
+def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
+ let summary = "Using shape.func to preserve shape computation";
+ let description = [{
+    This pass outlines the shape computation part in high-level IR by adding
+    shape.func ops and populating the corresponding mapping information into
+    ShapeMappingAnalysis. The shape computation part is usually introduced by
+    shape reification, and each dynamic shape is denoted by a shape.with_shape
+    op.
+
+    There are two main reasons this shape-outline pass is needed:
+    1. Many passes do not take the shape reification part into consideration.
+       Therefore we need to "remove" the shape reification part temporarily for
+       these passes.
+    2. Sometimes we cannot redo shape reification after converting from dialect
+       A to dialect B, because op-level shape reification is only implemented
+       on dialect A.
+
+ Input:
+
+ ```mlir
+ func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
+ tensor<?x4x?xf32> {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+ %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+ %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+ %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
+ %4 = shape.value_of %3 : tensor<?x4x?xf32>
+ %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
+ tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+ %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+ %7 = arith.addi %6, %c2 : index
+ %8 = shape.from_extents %7, %c4, %1 : index, index, index
+ %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
+ %10 = shape.value_of %9 : tensor<?x4x?xf32>
+ return %10 : tensor<?x4x?xf32>
+ }
+ ```
+
+    Output:
+ ```mlir
+ func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
+ tensor<?x4x?xf32> {
+ %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+ %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
+ tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+ }
+ shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+ %1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
+ %2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
+ %3 = arith.addi %2, %c2 : index
+ %4 = from_extents %3, %c4, %1 : index, index, index
+ return %4 : !shape.shape
+ }
+ shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
+ %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+ return %0 : tensor<3xindex>
+ }
+ ```
+
+    In the above example, the shape computation is inlined in the input IR and
+    is used to compute the shapes of two values (the test.abs and test.concat
+    results); in the output IR the shape computation part has been outlined
+    into shape.func ops.
+
+    The shape mapping information will be:
+
+ ```
+    // ---- Shape Mapping Information -----
+    // Shape for %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+    // Shape for %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+ ```
+ }];
+ let constructor = "mlir::createOutlineShapeComputationPass()";
+ let dependentDialects = ["shape::ShapeDialect"];
+}
+
def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
let summary = "Replace all cstr_ ops with a true witness";
let constructor = "mlir::createRemoveShapeConstraintsPass()";
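The description above shows the IR-level effect; as a usage sketch (the driver
function below is illustrative, not part of this commit), the pass can also be
scheduled programmatically. Since it is defined on ModuleOp, it goes on the
top-level pass manager:

  #include "mlir/Dialect/Shape/Transforms/Passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"

  // Outline all shape computations in `module`; the Value-to-shape.func
  // mapping is recorded in ShapeMappingAnalysis for later passes to query.
  static mlir::LogicalResult outlineShapeComputations(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    pm.addPass(mlir::createOutlineShapeComputationPass());
    return pm.run(module);
  }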
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index 40f1bd1941347..7c9b0d2e5e3a8 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRShapeOpsTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ OutlineShapeComputation.cpp
RemoveShapeConstraints.cpp
ShapeToShapeLowering.cpp
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
new file mode 100644
index 0000000000000..5d598a6f88109
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -0,0 +1,318 @@
+//===------ OutlineShapeComputation.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+#include <queue>
+#include <unordered_set>
+#include <vector>
+
+namespace mlir {
+#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "outline-shape-computation"
+
+using namespace mlir;
+
+namespace {
+
+// A Value is an input of the cluster if it is an operand of an operation in the
+// cluster and its defining operation is not in the cluster.
+SmallVector<Value, 4>
+getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
+ SmallVector<Value, 4> inputs;
+ llvm::SmallDenseSet<Value> inputSet;
+ llvm::SmallDenseSet<Operation *> opSet;
+ for (Operation *op : cluster) {
+ bool inserted = opSet.insert(op).second;
+ (void)inserted;
+ assert(inserted && "cluster contains duplicate operations");
+ }
+
+ for (Operation *op : cluster) {
+ for (Value operand : op->getOperands()) {
+ Operation *operandOp = operand.getDefiningOp();
+ if (opSet.find(operandOp) != opSet.end()) {
+ // Skip if defining op is in the cluster.
+ continue;
+ }
+ if (inputSet.insert(operand).second)
+ inputs.push_back(operand);
+ }
+ }
+ return inputs;
+}
+
+// Create a shape.func representing the shape computation for `shape`.
+std::pair<shape::FuncOp, SmallVector<Value>>
+createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
+ Value shape, StringRef fnName, Location loc) {
+ SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
+ auto fnType =
+ cluster.empty()
+ ? b.getFunctionType(shape.getType(), shape.getType())
+ : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
+ shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
+ Block *block = fnOp.addEntryBlock();
+ b.setInsertionPoint(block, block->end());
+ BlockAndValueMapping bvm;
+ if (cluster.empty()) {
+ bvm.map(shape, fnOp.getArgument(0));
+ } else {
+ for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
+ bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
+ }
+
+ for (Operation *op : cluster)
+ b.clone(*op, bvm);
+ llvm::SmallVector<Value, 4> fnReturns;
+ fnReturns.push_back(bvm.lookupOrDefault(shape));
+
+ b.create<shape::ReturnOp>(loc, fnReturns);
+ fnOp.setPrivate();
+ return std::make_pair(fnOp, inputs);
+}
+
+// The operations in a cluster may be unsorted, which would be inconvenient when
+// creating the shape.func op, so order them by their position in the function.
+DenseMap<Value, SmallVector<Operation *, 8>>
+getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
+ func::FuncOp funcOp) {
+ // Compute all clusters that each operation is in
+ DenseMap<Operation *, SmallVector<Value>> op2Shapes;
+ for (const auto &it : clusters) {
+ Value shape = it.first;
+ const DenseSet<Operation *> &cluster = it.second;
+ for (Operation *cOp : cluster)
+ op2Shapes[cOp].push_back(shape);
+ }
+
+  // Walk all operations in program order; for each operation, append it to the
+  // ordered cluster of every shape it belongs to.
+ DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
+ funcOp.walk([&](Operation *op) {
+ auto it = op2Shapes.find(op);
+ if (it != op2Shapes.end()) {
+ Operation *cOp = it->first;
+ for (Value shape : it->second)
+ orderedClusters[shape].push_back(cOp);
+ }
+ });
+
+ return orderedClusters;
+}
+
+void constructShapeFunc(
+ const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
+ DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
+ SymbolTable &symbolTable,
+ DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
+ func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
+ std::string shapeCalculationNamePrefix = "shape_cal_";
+ int shapeCalculationNameIdx = 0;
+ OpBuilder builder(context);
+
+ // Construct a shape function
+ for (shape::WithOp withOp : allWithOps) {
+ Value value = withOp.getOperand();
+ Value shape = withOp.getShape();
+ RankedTensorType rankedType = value.getType().dyn_cast<RankedTensorType>();
+ if (rankedType == nullptr)
+ continue;
+
+ const SmallVector<Operation *, 8> &cluster = clusters[shape];
+ shape::ShapeMappingValue shapeMappingValue;
+ auto it = dynShape2ShapeFunc.find(shape);
+ if (it == dynShape2ShapeFunc.end()) {
+ std::string name = shapeCalculationNamePrefix +
+ std::to_string(shapeCalculationNameIdx++);
+ Location loc = value.getLoc();
+ builder.setInsertionPointAfter(funcOp);
+ auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
+ const SmallVector<Value> &inputs = pair.second;
+ shape::FuncOp shapeFuncOp = pair.first;
+ StringAttr insertedName = symbolTable.insert(shapeFuncOp);
+ auto symbol = FlatSymbolRefAttr::get(context, insertedName);
+
+ shapeMappingValue.funcSymbol = symbol;
+ shapeMappingValue.inputs = inputs;
+ } else {
+ shapeMappingValue = it->second;
+ }
+ dynShape2ShapeFunc[shape] = shapeMappingValue;
+ shapeMappingAnalysis.shapeMapping.insert(
+ std::make_pair(value, shapeMappingValue));
+ }
+}
+
+struct OutlineShapeComputationPass
+ : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
+
+ void runOnOperation() override;
+
+private:
+ bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
+
+ void getClusterFromValue(Value shape,
+ DenseMap<Value, DenseSet<Operation *>> &clusters);
+
+ DenseMap<Value, SmallVector<Operation *, 8>>
+ constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
+ func::FuncOp funcOp);
+
+ DenseSet<Operation *> onlyUsedByWithShapes;
+};
+
+class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
+ using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::DimOp op,
+ PatternRewriter &rewriter) const override {
+ auto shapeOf =
+ rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
+ rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
+ op.getIndex());
+ return success();
+ }
+};
+
+void OutlineShapeComputationPass::runOnOperation() {
+ ModuleOp moduleOp = getOperation();
+ SymbolTable symbolTable(moduleOp);
+ DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
+ auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
+  // TODO: This is needed because we populate the analysis during a pass that
+  // mutates the IR; the pass currently assumes a single module is being compiled.
+ shapeMappingAnalysis.shapeMapping.clear();
+ markAnalysesPreserved<shape::ShapeMappingAnalysis>();
+
+ moduleOp.walk([&](func::FuncOp funcOp) {
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet prevPatterns(context);
+ prevPatterns.insert<TensorDimOpRewriter>(context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
+ return signalPassFailure();
+
+ // initialize class member `onlyUsedByWithShapes`
+ onlyUsedByWithShapes.clear();
+ funcOp.walk([&](Operation *op) {
+ calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
+ });
+ LLVM_DEBUG({
+ llvm::dbgs() << "onlyUsedByWithShapes table: \n";
+ for (auto it : onlyUsedByWithShapes)
+ llvm::dbgs() << *it << "\n";
+ });
+
+ // collect all the shape.with_shape ops.
+ std::vector<shape::WithOp> allWithOps;
+ funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
+
+ DenseMap<Value, SmallVector<Operation *, 8>> clusters =
+ constructClustersForEachShape(allWithOps, funcOp);
+ constructShapeFunc(allWithOps, context, clusters, symbolTable,
+ dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
+
+ for (shape::WithOp withOp : allWithOps) {
+ Value value = withOp.getOperand();
+ for (Operation *user : withOp.getResult().getUsers()) {
+ if (Value valueOf = llvm::dyn_cast<shape::ValueOfOp>(user))
+ valueOf.replaceAllUsesExcept(value, withOp);
+ }
+ }
+
+ // Apply patterns, note this also performs DCE.
+ if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
+ return signalPassFailure();
+ });
+}
+
+DenseMap<Value, SmallVector<Operation *, 8>>
+OutlineShapeComputationPass::constructClustersForEachShape(
+ const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
+ DenseMap<Value, DenseSet<Operation *>> clusters;
+ for (shape::WithOp withOp : allWithOps) {
+ Value shape = withOp.getShape();
+ if (clusters.count(shape) == 0)
+ getClusterFromValue(shape, clusters);
+ }
+ return getOrderedClusters(clusters, funcOp);
+}
+
+// The output of a cluster is the `shape`, and the inputs are the results of
+// operations that are not in `onlyUsedByWithShapes`.
+void OutlineShapeComputationPass::getClusterFromValue(
+ Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
+ DenseSet<Operation *> cluster;
+
+ DenseSet<Operation *> visited;
+ std::queue<Operation *> queue;
+
+ // defOp == nullptr means shape is the argument of the func op
+ if (Operation *defOp = shape.getDefiningOp()) {
+ visited.insert(defOp);
+ queue.push(defOp);
+ }
+ while (!queue.empty()) {
+ Operation *op = queue.front();
+ queue.pop();
+ if (onlyUsedByWithShapes.contains(op)) {
+ cluster.insert(op);
+ for (Value inp : op->getOperands()) {
+ Operation *inpDefOp = inp.getDefiningOp();
+ if (nullptr != inpDefOp && !visited.contains(inpDefOp)) {
+ visited.insert(inpDefOp);
+ queue.push(inpDefOp);
+ }
+ }
+ }
+ }
+
+ clusters[shape] = std::move(cluster);
+}
+
+// Returns true if `op` is a shape.with_shape op, or if all of `op`'s users
+// eventually feed the shape operand of shape.with_shape ops.
+bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
+ Operation *op, Value prevOutput) {
+ if (onlyUsedByWithShapes.contains(op))
+ return true;
+
+ if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
+ return withOp.getShape() == prevOutput;
+
+ if (op->use_empty())
+ return false;
+
+ for (Value oup : op->getResults())
+ for (Operation *user : oup.getUsers())
+ if (!calOnlyUsedByWithShapesRecursively(user, oup))
+ return false;
+
+ onlyUsedByWithShapes.insert(op);
+ return true;
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createOutlineShapeComputationPass() {
+ return std::make_unique<OutlineShapeComputationPass>();
+}
diff --git a/mlir/test/Dialect/Shape/outline-shape-computation.mlir b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
new file mode 100644
index 0000000000000..9e383af32d343
--- /dev/null
+++ b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
@@ -0,0 +1,208 @@
+// RUN: mlir-opt -outline-shape-computation -test-print-shape-mapping -split-input-file %s 2>&1 | FileCheck %s
+
+// Two dynamic shapes: one comes directly from shape.shape_of on an argument, the other is computed from its extents.
+func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+ // CHECK-DAG: Shape for {{.*}} = "test.abs"({{.*}}> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+ // CHECK-DAG: Shape for {{.*}} = "test.concat"({{.*}}> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+ %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+ %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+ %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
+ %4 = shape.value_of %3 : tensor<?x4x?xf32>
+ %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+ %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+ %7 = arith.addi %6, %c2 : index
+ %8 = shape.from_extents %7, %c4, %1 : index, index, index
+ %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
+ %10 = shape.value_of %9 : tensor<?x4x?xf32>
+ return %10 : tensor<?x4x?xf32>
+}
+
+// CHECK-LABEL: func.func @two_dynamic_one_direct_shape
+// CHECK-NEXT: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT: return %1 : tensor<?x4x?xf32>
+
+// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
+// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
+// CHECK-DAG: return %[[V4]] : !shape.shape
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
+// CHECK-DAG: %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+// CHECK-DAG: return %0 : tensor<3xindex>
+
+// -----
+
+// Two dynamic shapes and they share the same shape.func
+func.func @two_dynamic_share_same_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+ %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+ %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+ %3 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+ %4 = arith.addi %3, %c2 : index
+ %5 = shape.from_extents %4, %c4, %1 : index, index, index
+ %6 = shape.with_shape %2, %5 : tensor<?x4x?xf32>, !shape.shape
+ %7 = shape.value_of %6 : tensor<?x4x?xf32>
+ %8 = "test.abs"(%7) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+ %9 = shape.with_shape %8, %5 : tensor<?x4x?xf32>, !shape.shape
+ %10 = shape.value_of %9 : tensor<?x4x?xf32>
+ return %10 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: func.func @two_dynamic_share_same_shape
+// CHECK-NEXT: %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT: return %1 : tensor<?x4x?xf32>
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
+// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
+// CHECK-DAG: return %4 : !shape.shape
+// CHECK-NOT: shape_cal_1
+
+// -----
+
+// There's an internal dynamic shape source, and two other dynamic shapes share it.
+func.func @internal_dynamic_shape_source_shared(%arg0: tensor<?x4xf32>) -> tensor<?xi32> {
+ %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+ %1 = shape.shape_of %0 : tensor<?xi32> -> tensor<1xindex>
+ %2 = shape.with_shape %0, %1 : tensor<?xi32>, tensor<1xindex>
+ %3 = shape.value_of %2 : tensor<?xi32>
+ %4 = "test.abs"(%3) : (tensor<?xi32>) -> tensor<?xi32>
+ %5 = shape.with_shape %4, %1 : tensor<?xi32>, tensor<1xindex>
+ %6 = shape.value_of %5 : tensor<?xi32>
+ %7 = "test.negate"(%6) : (tensor<?xi32>) -> tensor<?xi32>
+ %8 = shape.with_shape %7, %1 : tensor<?xi32>, tensor<1xindex>
+ %9 = shape.value_of %8 : tensor<?xi32>
+ return %9 : tensor<?xi32>
+}
+// CHECK-LABEL: func.func @internal_dynamic_shape_source_shared
+// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor<?xi32>) -> tensor<?xi32>
+// CHECK-NEXT: %2 = "test.negate"(%1) : (tensor<?xi32>) -> tensor<?xi32>
+// CHECK-NEXT: return %2 : tensor<?xi32>
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?xi32>) -> tensor<1xindex> {
+// CHECK-NEXT: %0 = shape_of %arg0 : tensor<?xi32> -> tensor<1xindex>
+// CHECK-NEXT: return %0 : tensor<1xindex>
+// CHECK-NOT: shape_cal_1
+
+// -----
+
+// There's only a return op in the constructed shape.func
+func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
+ %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+ %1 = shape.with_shape %0, %arg1 : tensor<?xi32>, tensor<1xindex>
+ %2 = shape.value_of %1 : tensor<?xi32>
+ return %2 : tensor<?xi32>
+}
+// CHECK-LABEL: func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
+// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+// CHECK-NEXT: return %0 : tensor<?xi32>
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<1xindex>) -> tensor<1xindex> {
+// CHECK-NEXT: return %arg0 : tensor<1xindex>
+
+// -----
+
+// Shape computation part interleaves with general computation.
+func.func @interleaved_shape_computation(%arg0: tensor<?x4x5xf32>, %arg1: tensor<?x4x5xf32>, %arg2: tensor<?x4x5xf32>) -> (tensor<?x4x5xf32>, index) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c5 = arith.constant 5 : index
+ %0 = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
+ %1 = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
+ %2 = shape.shape_of %arg2 : tensor<?x4x5xf32> -> tensor<3xindex>
+ %3 = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
+ %4 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+ %5 = shape.get_extent %1, %c0 : tensor<3xindex>, index -> index
+ %6 = shape.get_extent %2, %c0 : tensor<3xindex>, index -> index
+ %7 = arith.addi %4, %5 : index
+ %8 = arith.addi %7, %6 : index
+ %9 = shape.from_extents %8, %c4, %c5 : index, index, index
+ %10 = shape.with_shape %3, %9 : tensor<?x4x5xf32>, !shape.shape
+ %11 = shape.value_of %10 : tensor<?x4x5xf32>
+ return %11, %7 : tensor<?x4x5xf32>, index
+}
+// CHECK-LABEL: func.func @interleaved_shape_computation
+// CHECK-DAG: %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
+// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
+// CHECK-DAG: %[[V3:.*]] = shape.get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V4:.*]] = shape.get_extent %[[V1]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : index
+// CHECK-DAG: return %[[V2]], %[[V5]] : tensor<?x4x5xf32>, index
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x5xf32>, %arg1: index, %arg2: index) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %arg1 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V2:.*]] = arith.addi %arg2, %[[V1]] : index
+// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %c5 : index, index, index
+// CHECK-DAG: return %[[V3]] : !shape.shape
+
+// -----
+
+// There are multiple reused shape computations.
+func.func @multiple_reused(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %0 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+ %1 = shape.shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
+ %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+ %3 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+ %4 = shape.get_extent %0, %c0 : tensor<2xindex>, index -> index
+ %5 = shape.get_extent %1, %c0 : tensor<2xindex>, index -> index
+ %6 = arith.addi %4, %5 : index
+ %7 = shape.from_extents %6, %c4 : index, index
+ %8 = shape.with_shape %2, %7 : tensor<?x4xf32>, !shape.shape
+ %9 = shape.with_shape %3, %7 : tensor<?x4xf32>, !shape.shape
+ %10 = shape.value_of %8 : tensor<?x4xf32>
+ %11 = shape.value_of %9 : tensor<?x4xf32>
+ %12 = "test.concat"(%arg0, %2) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+ %13 = "test.concat"(%arg0, %3) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+ %14 = arith.addi %6, %4 : index
+ %15 = shape.from_extents %14, %c4 : index, index
+ %16 = shape.with_shape %12, %15 : tensor<?x4xf32>, !shape.shape
+ %17 = shape.with_shape %13, %15 : tensor<?x4xf32>, !shape.shape
+ %18 = shape.value_of %16 : tensor<?x4xf32>
+ %19 = shape.value_of %17 : tensor<?x4xf32>
+ return %10, %11, %18, %19 : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
+}
+// CHECK-LABEL: func.func @multiple_reused
+// CHECK-DAG: %[[V0:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG: %[[V1:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %[[V0]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG: %[[V3:.*]] = "test.concat"(%arg0, %[[V1]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG: return %[[V0]], %[[V1]], %[[V2]], %[[V3]] : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
+
+// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
+// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V4]], %[[V2]] : index
+// CHECK-DAG: %[[V6:.*]] = from_extents %[[V5]], %c4 : index, index
+// CHECK-DAG: return %[[V6]] : !shape.shape
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
+// CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
+// CHECK-DAG: return %[[V5]] : !shape.shape
+
diff --git a/mlir/test/lib/Dialect/Shape/CMakeLists.txt b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
index 2d142d3949c0a..545f13db25a84 100644
--- a/mlir/test/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRShapeTestPasses
TestShapeFunctions.cpp
+ TestShapeMappingAnalysis.cpp
EXCLUDE_FROM_LIBMLIR
@@ -11,6 +12,7 @@ add_mlir_library(MLIRShapeTestPasses
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
+ MLIRShapeOpsTransforms
MLIRShapeDialect
MLIRSupport
)
diff --git a/mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp b/mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp
new file mode 100644
index 0000000000000..f50988e3b3319
--- /dev/null
+++ b/mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp
@@ -0,0 +1,43 @@
+//===- TestShapeMappingAnalysis.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestShapeMappingPass
+ : public PassWrapper<TestShapeMappingPass, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShapeMappingPass)
+
+ StringRef getArgument() const final { return "test-print-shape-mapping"; }
+ StringRef getDescription() const final {
+ return "Print the contents of a constructed shape mapping information.";
+ }
+ void runOnOperation() override {
+ llvm::Optional<std::reference_wrapper<shape::ShapeMappingAnalysis>>
+ maybeAnalysis = getCachedAnalysis<shape::ShapeMappingAnalysis>();
+ if (maybeAnalysis.has_value())
+ maybeAnalysis.value().get().print(llvm::errs());
+ else
+ llvm::errs() << "No cached ShapeMappingAnalysis existed.";
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestShapeMappingPass() {
+ PassRegistration<TestShapeMappingPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 373a048d091ec..37d331b1000d5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -109,6 +109,7 @@ void registerTestPDLLPasses();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
+void registerTestShapeMappingPass();
void registerTestSliceAnalysisPass();
void registerTestTensorTransforms();
void registerTestTilingInterface();
@@ -208,6 +209,7 @@ void registerTestPasses() {
mlir::test::registerTestPDLLPasses();
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
+ mlir::test::registerTestShapeMappingPass();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTilingInterface();