[Mlir-commits] [mlir] 9f77909 - [mlir][shape] add outline-shape-computation pass

Jacques Pienaar llvmlistbot at llvm.org
Sun Oct 2 20:24:59 PDT 2022


Author: Yuanqiang Liu
Date: 2022-10-02T20:24:49-07:00
New Revision: 9f77909a5e07b7973fe13d4ea6391c29ff1b46b5

URL: https://github.com/llvm/llvm-project/commit/9f77909a5e07b7973fe13d4ea6391c29ff1b46b5
DIFF: https://github.com/llvm/llvm-project/commit/9f77909a5e07b7973fe13d4ea6391c29ff1b46b5.diff

LOG: [mlir][shape] add outline-shape-computation pass

Add outline-shape-computation pass. This pass outlines the
shape computation part in high level IR by adding shape.func and
populates the corresponding mapping information into ShapeMappingAnalysis.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D131810

Added: 
    mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
    mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
    mlir/test/Dialect/Shape/outline-shape-computation.mlir
    mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp

Modified: 
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
    mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Shape/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h b/mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
new file mode 100644
index 0000000000000..25befa3590856
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
@@ -0,0 +1,60 @@
+//===- ShapeMappingAnalysis.h - Preserve shape Info  ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+#define MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+namespace shape {
+
+/// ShapeMappingValue is the value type of the ShapeMappingAnalysis table:
+/// `funcSymbol` is the symbol of the shape function, and `inputs` are the
+/// actual arguments to pass to that function.
+struct ShapeMappingValue {
+  ShapeMappingValue() = default;
+  /// Takes ownership of `inps`: the rvalue-reference parameter is moved into
+  /// `inputs` instead of being copied.
+  ShapeMappingValue(FlatSymbolRefAttr symbol, llvm::SmallVector<Value> &&inps)
+      : funcSymbol(symbol), inputs(std::move(inps)) {}
+
+  /// Symbol of the shape.func that computes the shape.
+  FlatSymbolRefAttr funcSymbol;
+  /// Actual arguments for `funcSymbol`, in parameter order.
+  llvm::SmallVector<Value> inputs;
+};
+
+/// ShapeMappingAnalysis is populated by OutlineShapeComputationPass: for each
+/// shaped Value whose shape computation was outlined, it records the shape
+/// function computing that shape and the arguments to call it with.
+struct ShapeMappingAnalysis {
+  ShapeMappingAnalysis(Operation *op) : operation(op) { (void)operation; }
+
+  /// Dumps the shape mapping information to the given stream: a header line,
+  /// then one "// Shape for <value> :: @fn(<args>)" line per entry, in
+  /// DenseMap iteration order.
+  void print(raw_ostream &os) const {
+    os << "// ---- Shape Mapping Information -----\n";
+    for (const auto &entry : shapeMapping) {
+      Value shapedValue = entry.getFirst();
+      const ShapeMappingValue &mapping = entry.getSecond();
+      os << "// Shape for " << shapedValue << " :: " << mapping.funcSymbol
+         << "(";
+      llvm::interleaveComma(mapping.inputs, os);
+      os << ")\n";
+    }
+  }
+
+  /// Maps a shaped Value to the shape function (and its arguments) that
+  /// computes the Value's shape.
+  llvm::DenseMap<Value, ShapeMappingValue> shapeMapping;
+
+private:
+  /// Operation the analysis was constructed on; currently unused.
+  Operation *operation;
+};
+
+} // namespace shape
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index be3c74123d065..cfb637f133f54 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,7 @@
 
 namespace mlir {
 class ConversionTarget;
+class ModuleOp;
 class TypeConverter;
 namespace func {
 class FuncOp;
@@ -53,6 +54,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass();
 // level.
 std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
 
+/// Outlines the shape computation part by adding shape.func ops and populates
+/// the corresponding mapping information into ShapeMappingAnalysis.
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 503780f0d1c7c..9dfda9ea33615 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,6 +11,88 @@
 
 include "mlir/Pass/PassBase.td"
 
+def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
+  let summary = "Using shape.func to preserve shape computation";
+  let description = [{
+    This pass outlines the shape computation part in high level IR by adding
+    shape.func and populates the corresponding mapping information into
+    ShapeMappingAnalysis. The shape computation part is usually introduced by
+    shape reification, and each single dynamic shape is denoted by shape.with_shape.
+
+    There are two main reasons this shape-outline pass is needed:
+    1. Many passes don't take shape reification part into consideration.
+       Therefore we need to "remove" the shape reification part temporarily for
+       these passes.
+    2. Sometimes we cannot redo shape reification after converting from dialect
+       A to dialect B, because op-level shape reification is only implemented
+       on A.
+
+    Input:
+
+    ```mlir
+    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
+      tensor<?x4x?xf32> {
+      %c2 = arith.constant 2 : index
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+      %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+      %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
+      %4 = shape.value_of %3 : tensor<?x4x?xf32>
+      %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
+            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+      %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+      %7 = arith.addi %6, %c2 : index
+      %8 = shape.from_extents %7, %c4, %1 : index, index, index
+      %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
+      %10 = shape.value_of %9 : tensor<?x4x?xf32>
+      return %10 : tensor<?x4x?xf32>
+    }
+    ```
+
+    Output:
+    ```mlir
+    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
+      tensor<?x4x?xf32> {
+      %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+      %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
+            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+      return %1 : tensor<?x4x?xf32>
+    }
+    shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
+      %c2 = arith.constant 2 : index
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      %1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
+      %2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
+      %3 = arith.addi %2, %c2 : index
+      %4 = from_extents %3, %c4, %1 : index, index, index
+      return %4 : !shape.shape
+    }
+    shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
+      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      return %0 : tensor<3xindex>
+    }
+    ```
+
+    In the above example, the shape computation is inlined in the input IR,
+    where it computes the shapes of two values (test.abs and test.concat). In
+    the output IR, that shape computation part is outlined.
+
+    The shape mapping information will then be (the order of the entries is
+    unspecified):
+
+    ```
+    // ---- Shape Mapping Information -----
+    // Shape for %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+    // Shape for %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+    ```
+  }];
+  let constructor = "mlir::createOutlineShapeComputationPass()";
+  let dependentDialects = ["shape::ShapeDialect"];
+}
+
 def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
   let summary = "Replace all cstr_ ops with a true witness";
   let constructor = "mlir::createRemoveShapeConstraintsPass()";

diff  --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index 40f1bd1941347..7c9b0d2e5e3a8 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  OutlineShapeComputation.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
 

diff  --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
new file mode 100644
index 0000000000000..5d598a6f88109
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -0,0 +1,318 @@
+//====----- OutlineShapeComputation.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+#include <queue>
+#include <unordered_set>
+#include <vector>
+
+namespace mlir {
+#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "outline-shape-computation"
+
+using namespace mlir;
+
+namespace {
+
+// Collects the values flowing into `cluster` from outside: operands of
+// cluster operations whose defining operation is not itself part of the
+// cluster (block arguments therefore always count). Each input is listed once,
+// in order of first use.
+SmallVector<Value, 4>
+getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
+  llvm::SmallDenseSet<Operation *> clusterOps;
+  for (Operation *clusterOp : cluster) {
+    bool isNew = clusterOps.insert(clusterOp).second;
+    (void)isNew;
+    assert(isNew && "cluster contains duplicate operations");
+  }
+
+  SmallVector<Value, 4> result;
+  llvm::SmallDenseSet<Value> seen;
+  for (Operation *clusterOp : cluster) {
+    for (Value operand : clusterOp->getOperands()) {
+      // Values produced inside the cluster are internal, not inputs.
+      if (clusterOps.contains(operand.getDefiningOp()))
+        continue;
+      if (!seen.insert(operand).second)
+        continue;
+      result.push_back(operand);
+    }
+  }
+  return result;
+}
+
+// Create a shape.func representing the shape computation for `shape`.
+//
+// The function's parameters are the inputs of `cluster` (see
+// getInputsOfCluster), its body is a clone of the cluster, and it returns the
+// clone of `shape`. If `cluster` is empty, `shape` itself is the only input and
+// the function simply returns its argument. Returns the new function together
+// with the values to pass as its arguments, in parameter order.
+std::pair<shape::FuncOp, SmallVector<Value>>
+createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
+                      Value shape, StringRef fnName, Location loc) {
+  SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
+  // With nothing to outline, `shape` is computed outside the function and must
+  // be passed in. Recording it in `inputs` keeps the returned argument list in
+  // sync with the function signature (one argument per parameter).
+  if (cluster.empty())
+    inputs.push_back(shape);
+  auto fnType =
+      b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
+  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
+  Block *block = fnOp.addEntryBlock();
+  b.setInsertionPoint(block, block->end());
+  BlockAndValueMapping bvm;
+  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
+    bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
+
+  for (Operation *op : cluster)
+    b.clone(*op, bvm);
+  llvm::SmallVector<Value, 4> fnReturns;
+  fnReturns.push_back(bvm.lookupOrDefault(shape));
+
+  b.create<shape::ReturnOp>(loc, fnReturns);
+  fnOp.setPrivate();
+  return std::make_pair(fnOp, inputs);
+}
+
+// Converts the unordered cluster sets into vectors listing each cluster's
+// operations in the order `funcOp.walk` visits them; an ordered list is needed
+// to clone the cluster into a shape.func.
+DenseMap<Value, SmallVector<Operation *, 8>>
+getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
+                   func::FuncOp funcOp) {
+  // Invert the mapping: for every operation, the shapes whose cluster holds it.
+  DenseMap<Operation *, SmallVector<Value>> shapesOfOp;
+  for (const auto &entry : clusters)
+    for (Operation *member : entry.second)
+      shapesOfOp[member].push_back(entry.first);
+
+  // Visit operations in walk order and append each one to every cluster it
+  // belongs to.
+  DenseMap<Value, SmallVector<Operation *, 8>> result;
+  funcOp.walk([&](Operation *visited) {
+    auto found = shapesOfOp.find(visited);
+    if (found == shapesOfOp.end())
+      return;
+    for (Value owningShape : found->second)
+      result[owningShape].push_back(visited);
+  });
+  return result;
+}
+
+// For every shape.with_shape op whose operand is a ranked tensor, makes sure a
+// shape.func computing its shape exists. A function is created right after
+// `funcOp` the first time a shape value is seen and reused afterwards. The
+// value -> (function, arguments) pair is recorded in `shapeMappingAnalysis`.
+void constructShapeFunc(
+    const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
+    DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
+    SymbolTable &symbolTable,
+    DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
+    func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
+  const std::string namePrefix = "shape_cal_";
+  int nextNameIdx = 0;
+  OpBuilder builder(context);
+
+  for (shape::WithOp withOp : allWithOps) {
+    Value value = withOp.getOperand();
+    Value shape = withOp.getShape();
+    // Only ranked tensors get a shape function.
+    if (!value.getType().isa<RankedTensorType>())
+      continue;
+
+    const SmallVector<Operation *, 8> &cluster = clusters[shape];
+    shape::ShapeMappingValue mapping;
+    auto existing = dynShape2ShapeFunc.find(shape);
+    if (existing != dynShape2ShapeFunc.end()) {
+      // This shape value was already outlined; reuse its function.
+      mapping = existing->second;
+    } else {
+      std::string fnName = namePrefix + std::to_string(nextNameIdx++);
+      builder.setInsertionPointAfter(funcOp);
+      std::pair<shape::FuncOp, SmallVector<Value>> created =
+          createFuncFromCluster(builder, cluster, shape, fnName,
+                                value.getLoc());
+      // SymbolTable::insert returns the (possibly uniqued) symbol name.
+      StringAttr uniqueName = symbolTable.insert(created.first);
+      mapping.funcSymbol = FlatSymbolRefAttr::get(context, uniqueName);
+      mapping.inputs = created.second;
+    }
+    dynShape2ShapeFunc[shape] = mapping;
+    shapeMappingAnalysis.shapeMapping.insert(std::make_pair(value, mapping));
+  }
+}
+
+/// Module pass that, per func.func, outlines the computations feeding the
+/// shape operands of shape.with_shape ops into private shape.func ops and
+/// records the resulting mapping in ShapeMappingAnalysis.
+struct OutlineShapeComputationPass
+    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
+public:
+  void runOnOperation() override;
+
+private:
+  /// Decides (and memoizes in `onlyUsedByWithShapes`) whether `op` only
+  /// feeds shape operands of shape.with_shape ops.
+  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
+
+  /// Stores in `clusters[shape]` the unordered set of ops producing `shape`.
+  void getClusterFromValue(Value shape,
+                           DenseMap<Value, DenseSet<Operation *>> &clusters);
+
+  /// Builds one ordered cluster per distinct shape operand in `allWithOps`.
+  DenseMap<Value, SmallVector<Operation *, 8>>
+  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
+                                func::FuncOp funcOp);
+
+  /// Ops whose results end up only as shape.with_shape shape operands
+  /// (positive results of calOnlyUsedByWithShapesRecursively).
+  DenseSet<Operation *> onlyUsedByWithShapes;
+};
+
+/// Rewrites `tensor.dim %src, %idx` into
+/// `shape.get_extent (shape.shape_of %src), %idx`, so dimension queries are
+/// expressed with shape dialect ops.
+class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = dimOp.getLoc();
+    Value source = dimOp.getSource();
+    Value shapeOfSource = rewriter.create<shape::ShapeOfOp>(loc, source);
+    rewriter.replaceOpWithNewOp<shape::GetExtentOp>(
+        dimOp, dimOp.getType(), shapeOfSource, dimOp.getIndex());
+    return success();
+  }
+};
+
+void OutlineShapeComputationPass::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  SymbolTable symbolTable(moduleOp);
+  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
+  // TODO: This is as we populate this analysis during a pass that mutates. This
+  // pass currently requires 1 single module being compiled.
+  shapeMappingAnalysis.shapeMapping.clear();
+  markAnalysesPreserved<shape::ShapeMappingAnalysis>();
+
+  // Outlines the shape computations of one function. Fails if one of the
+  // greedy pattern applications fails.
+  auto processFunc = [&](func::FuncOp funcOp) -> LogicalResult {
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet prevPatterns(context);
+    prevPatterns.insert<TensorDimOpRewriter>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
+      return failure();
+
+    // initialize class member `onlyUsedByWithShapes`
+    onlyUsedByWithShapes.clear();
+    funcOp.walk([&](Operation *op) {
+      calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
+    });
+    LLVM_DEBUG({
+      llvm::dbgs() << "onlyUsedByWithShapes table: \n";
+      for (auto it : onlyUsedByWithShapes)
+        llvm::dbgs() << *it << "\n";
+    });
+
+    // collect all the shape.with_shape ops.
+    std::vector<shape::WithOp> allWithOps;
+    funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
+
+    DenseMap<Value, SmallVector<Operation *, 8>> clusters =
+        constructClustersForEachShape(allWithOps, funcOp);
+    // Shape values are local to `funcOp`, so the shape -> shape.func cache is
+    // kept per function. Sharing it across functions could produce stale hits:
+    // once a previous function's shape ops are erased, a newly created op may
+    // reuse their memory and compare equal as a `Value` key.
+    DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
+    constructShapeFunc(allWithOps, context, clusters, symbolTable,
+                       dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
+
+    for (shape::WithOp withOp : allWithOps) {
+      Value value = withOp.getOperand();
+      for (Operation *user : withOp.getResult().getUsers()) {
+        // Test the cast before using the op: converting a null ValueOfOp
+        // directly into a Value would dereference a null operation.
+        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user))
+          valueOf.getResult().replaceAllUsesExcept(value, withOp);
+      }
+    }
+
+    // Apply patterns, note this also performs DCE.
+    if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
+      return failure();
+    return success();
+  };
+
+  // Stop at the first failing function instead of rewriting the rest.
+  WalkResult walkResult = moduleOp.walk([&](func::FuncOp funcOp) {
+    return failed(processFunc(funcOp)) ? WalkResult::interrupt()
+                                       : WalkResult::advance();
+  });
+  if (walkResult.wasInterrupted())
+    signalPassFailure();
+}
+
+// Computes the cluster of every distinct shape operand of `allWithOps` and
+// returns the clusters with their operations ordered as in `funcOp`.
+DenseMap<Value, SmallVector<Operation *, 8>>
+OutlineShapeComputationPass::constructClustersForEachShape(
+    const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
+  DenseMap<Value, DenseSet<Operation *>> unorderedClusters;
+  for (shape::WithOp withOp : allWithOps) {
+    Value shape = withOp.getShape();
+    // Several with_shape ops may share one shape; compute its cluster once.
+    if (!unorderedClusters.contains(shape))
+      getClusterFromValue(shape, unorderedClusters);
+  }
+  return getOrderedClusters(unorderedClusters, funcOp);
+}
+
+// Stores in `clusters[shape]` the operations that compute `shape`. The search
+// starts at `shape`'s defining op and walks backwards through operands,
+// expanding only through ops in `onlyUsedByWithShapes`; values produced by
+// any other op (or block arguments) are left outside and become cluster inputs.
+void OutlineShapeComputationPass::getClusterFromValue(
+    Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
+  DenseSet<Operation *> cluster;
+  DenseSet<Operation *> visited;
+  std::queue<Operation *> worklist;
+
+  // A null defining op means `shape` is a block argument; the cluster is then
+  // empty.
+  Operation *shapeDef = shape.getDefiningOp();
+  if (shapeDef != nullptr) {
+    visited.insert(shapeDef);
+    worklist.push(shapeDef);
+  }
+
+  while (!worklist.empty()) {
+    Operation *current = worklist.front();
+    worklist.pop();
+    // Ops that also feed non-shape computations stay outside the cluster.
+    if (!onlyUsedByWithShapes.contains(current))
+      continue;
+    cluster.insert(current);
+    for (Value operand : current->getOperands()) {
+      Operation *producer = operand.getDefiningOp();
+      if (producer == nullptr)
+        continue;
+      if (visited.insert(producer).second)
+        worklist.push(producer);
+    }
+  }
+
+  clusters[shape] = std::move(cluster);
+}
+
+// Returns true if `op` is a shape.with_shape whose shape operand is
+// `prevOutput`, or if every user of every result of `op` recursively satisfies
+// this condition, i.e. the results of `op` eventually only feed shape operands
+// of shape.with_shape ops. Ops without uses yield false. Positive answers for
+// non-with_shape ops are memoized in `onlyUsedByWithShapes`.
+bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
+    Operation *op, Value prevOutput) {
+  if (onlyUsedByWithShapes.contains(op))
+    return true;
+
+  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
+    return withOp.getShape() == prevOutput;
+
+  if (op->use_empty())
+    return false;
+
+  bool everyUserQualifies = llvm::all_of(op->getResults(), [&](Value result) {
+    return llvm::all_of(result.getUsers(), [&](Operation *user) {
+      return calOnlyUsedByWithShapesRecursively(user, result);
+    });
+  });
+  if (!everyUserQualifies)
+    return false;
+
+  onlyUsedByWithShapes.insert(op);
+  return true;
+}
+
+} // namespace
+
+/// Creates a new instance of the outline-shape-computation pass.
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createOutlineShapeComputationPass() {
+  auto pass = std::make_unique<OutlineShapeComputationPass>();
+  return pass;
+}

diff  --git a/mlir/test/Dialect/Shape/outline-shape-computation.mlir b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
new file mode 100644
index 0000000000000..9e383af32d343
--- /dev/null
+++ b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
@@ -0,0 +1,208 @@
+// RUN: mlir-opt -outline-shape-computation -test-print-shape-mapping -split-input-file %s 2>&1 | FileCheck %s
+
+// Two dynamic shapes: one is directly shape.shape_of(%arg0), and the other is computed from it.
+func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  // CHECK-DAG: Shape for {{.*}} = "test.abs"({{.*}}> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+  // CHECK-DAG: Shape for {{.*}} = "test.concat"({{.*}}> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+  %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+  %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
+  %4 = shape.value_of %3 : tensor<?x4x?xf32>
+  %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+  %7 = arith.addi %6, %c2 : index
+  %8 = shape.from_extents %7, %c4, %1 : index, index, index
+  %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
+  %10 = shape.value_of %9 : tensor<?x4x?xf32>
+  return %10 : tensor<?x4x?xf32>
+}
+
+// CHECK-LABEL:  func.func @two_dynamic_one_direct_shape
+// CHECK-NEXT:     %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT:     %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT:     return %1 : tensor<?x4x?xf32>
+
+// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
+// CHECK-DAG:      %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+// CHECK-DAG:      %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
+// CHECK-DAG:      %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG:      %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
+// CHECK-DAG:      %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
+// CHECK-DAG:      return %[[V4]] : !shape.shape
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
+// CHECK-DAG:   %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+// CHECK-DAG:   return %0 : tensor<3xindex>
+
+// -----
+
+// Two dynamic shapes that share the same shape.func.
+func.func @two_dynamic_share_same_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  %3 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+  %4 = arith.addi %3, %c2 : index
+  %5 = shape.from_extents %4, %c4, %1 : index, index, index
+  %6 = shape.with_shape %2, %5 : tensor<?x4x?xf32>, !shape.shape
+  %7 = shape.value_of %6 : tensor<?x4x?xf32>
+  %8 = "test.abs"(%7) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+  %9 = shape.with_shape %8, %5 : tensor<?x4x?xf32>, !shape.shape
+  %10 = shape.value_of %9 : tensor<?x4x?xf32>
+  return %10 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: func.func @two_dynamic_share_same_shape
+// CHECK-NEXT:     %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT:     %1 = "test.abs"(%0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT:     return %1 : tensor<?x4x?xf32>
+
+// CHECK:       shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
+// CHECK-DAG:     %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+// CHECK-DAG:     %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
+// CHECK-DAG:     %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG:     %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
+// CHECK-DAG:     %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
+// CHECK-DAG:     return %[[V4]] : !shape.shape
+// CHECK-NOT: shape_cal_1
+
+// -----
+
+// There's an internal dynamic shape source, and two other dynamic shapes share it.
+func.func @internal_dynamic_shape_source_shared(%arg0: tensor<?x4xf32>) -> tensor<?xi32> {
+  %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+  %1 = shape.shape_of %0 : tensor<?xi32> -> tensor<1xindex>
+  %2 = shape.with_shape %0, %1 : tensor<?xi32>, tensor<1xindex>
+  %3 = shape.value_of %2 : tensor<?xi32>
+  %4 = "test.abs"(%3) : (tensor<?xi32>) -> tensor<?xi32>
+  %5 = shape.with_shape %4, %1 : tensor<?xi32>, tensor<1xindex>
+  %6 = shape.value_of %5 : tensor<?xi32>
+  %7 = "test.negate"(%6) : (tensor<?xi32>) -> tensor<?xi32>
+  %8 = shape.with_shape %7, %1 : tensor<?xi32>, tensor<1xindex>
+  %9 = shape.value_of %8 : tensor<?xi32>
+  return %9 : tensor<?xi32>
+}
+// CHECK-LABEL: func.func @internal_dynamic_shape_source_shared
+// CHECK-NEXT:     %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+// CHECK-NEXT:     %1 = "test.abs"(%0) : (tensor<?xi32>) -> tensor<?xi32>
+// CHECK-NEXT:     %2 = "test.negate"(%1) : (tensor<?xi32>) -> tensor<?xi32>
+// CHECK-NEXT:     return %2 : tensor<?xi32>
+
+// CHECK:      shape.func private @shape_cal_0(%arg0: tensor<?xi32>) -> tensor<1xindex> {
+// CHECK-NEXT:   %0 = shape_of %arg0 : tensor<?xi32> -> tensor<1xindex>
+// CHECK-NEXT:   return %0 : tensor<1xindex>
+// CHECK-NOT: shape_cal_1
+
+// -----
+
+// The shape is a function argument, so the constructed shape.func only
+// contains a return op.
+func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
+  %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+  %1 = shape.with_shape %0, %arg1 : tensor<?xi32>, tensor<1xindex>
+  %2 = shape.value_of %1 : tensor<?xi32>
+  return %2 : tensor<?xi32>
+}
+// CHECK-LABEL: func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
+// CHECK-NEXT:   %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
+// CHECK-NEXT:   return %0 : tensor<?xi32>
+
+// CHECK:      shape.func private @shape_cal_0(%arg0: tensor<1xindex>) -> tensor<1xindex> {
+// CHECK-NEXT:   return %arg0 : tensor<1xindex>
+
+// -----
+
+// The shape computation part is interleaved with general computation: %7 is
+// used both for the shape and as a function result, so it stays in the
+// function and is passed to the outlined shape.func as an argument.
+func.func @interleaved_shape_computation(%arg0: tensor<?x4x5xf32>, %arg1: tensor<?x4x5xf32>, %arg2: tensor<?x4x5xf32>) -> (tensor<?x4x5xf32>, index) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %0 = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
+  %1 = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
+  %2 = shape.shape_of %arg2 : tensor<?x4x5xf32> -> tensor<3xindex>
+  %3 = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
+  %4 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+  %5 = shape.get_extent %1, %c0 : tensor<3xindex>, index -> index
+  %6 = shape.get_extent %2, %c0 : tensor<3xindex>, index -> index
+  %7 = arith.addi %4, %5 : index
+  %8 = arith.addi %7, %6 : index
+  %9 = shape.from_extents %8, %c4, %c5 : index, index, index
+  %10 = shape.with_shape %3, %9 : tensor<?x4x5xf32>, !shape.shape
+  %11 = shape.value_of %10 : tensor<?x4x5xf32>
+  return %11, %7 : tensor<?x4x5xf32>, index
+}
+// CHECK-LABEL: func.func @interleaved_shape_computation
+// CHECK-DAG:   %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
+// CHECK-DAG:   %[[V1:.*]] = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
+// CHECK-DAG:   %[[V2:.*]] = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
+// CHECK-DAG:   %[[V3:.*]] = shape.get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG:   %[[V4:.*]] = shape.get_extent %[[V1]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG:   %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : index
+// CHECK-DAG:   return %[[V2]], %[[V5]] : tensor<?x4x5xf32>, index
+
+// CHECK:     shape.func private @shape_cal_0(%arg0: tensor<?x4x5xf32>, %arg1: index, %arg2: index) -> !shape.shape {
+// CHECK-DAG:   %[[V0:.*]] = shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
+// CHECK-DAG:   %[[V1:.*]] = get_extent %[[V0]], %arg1 : tensor<3xindex>, index -> index
+// CHECK-DAG:   %[[V2:.*]] = arith.addi %arg2, %[[V1]] : index
+// CHECK-DAG:   %[[V3:.*]] = from_extents %[[V2]], %c4, %c5 : index, index, index
+// CHECK-DAG:   return %[[V3]] : !shape.shape
+
+// -----
+
+// There are multiple reused shape computations.
+func.func @multiple_reused(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %0 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+  %1 = shape.shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
+  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  %3 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  %4 = shape.get_extent %0, %c0 : tensor<2xindex>, index -> index
+  %5 = shape.get_extent %1, %c0 : tensor<2xindex>, index -> index
+  %6 = arith.addi %4, %5 : index
+  %7 = shape.from_extents %6, %c4 : index, index
+  %8 = shape.with_shape %2, %7 : tensor<?x4xf32>, !shape.shape
+  %9 = shape.with_shape %3, %7 : tensor<?x4xf32>, !shape.shape
+  %10 = shape.value_of %8 : tensor<?x4xf32>
+  %11 = shape.value_of %9 : tensor<?x4xf32>
+  %12 = "test.concat"(%arg0, %2) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  %13 = "test.concat"(%arg0, %3) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  %14 = arith.addi %6, %4 : index
+  %15 = shape.from_extents %14, %c4 : index, index
+  %16 = shape.with_shape %12, %15 : tensor<?x4xf32>, !shape.shape
+  %17 = shape.with_shape %13, %15 : tensor<?x4xf32>, !shape.shape
+  %18 = shape.value_of %16 : tensor<?x4xf32>
+  %19 = shape.value_of %17 : tensor<?x4xf32>
+  return %10, %11, %18, %19 : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
+}
+// CHECK-LABEL: func.func @multiple_reused
+// CHECK-DAG:     %[[V0:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG:     %[[V1:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG:     %[[V2:.*]] = "test.concat"(%arg0, %[[V0]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG:     %[[V3:.*]] = "test.concat"(%arg0, %[[V1]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+// CHECK-DAG:     return %[[V0]], %[[V1]], %[[V2]], %[[V3]] : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
+
+// CHECK:      shape.func private @shape_cal_1(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
+// CHECK-DAG:    %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG:    %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG:    %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG:    %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG:    %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
+// CHECK-DAG:    %[[V5:.*]] = arith.addi %[[V4]], %[[V2]] : index
+// CHECK-DAG:    %[[V6:.*]] = from_extents %[[V5]], %c4 : index, index
+// CHECK-DAG:    return %[[V6]] : !shape.shape
+
+// CHECK:     shape.func private @shape_cal_0(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
+// CHECK-DAG:   %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG:   %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK-DAG:   %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG:   %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG:   %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
+// CHECK-DAG:   %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
+// CHECK-DAG:   return %[[V5]] : !shape.shape
+

diff  --git a/mlir/test/lib/Dialect/Shape/CMakeLists.txt b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
index 2d142d3949c0a..545f13db25a84 100644
--- a/mlir/test/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRShapeTestPasses
   TestShapeFunctions.cpp
+  TestShapeMappingAnalysis.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -11,6 +12,7 @@ add_mlir_library(MLIRShapeTestPasses
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRPass
+  MLIRShapeOpsTransforms
   MLIRShapeDialect
   MLIRSupport
   )

diff  --git a/mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp b/mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp
new file mode 100644
index 0000000000000..f50988e3b3319
--- /dev/null
+++ b/mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp
@@ -0,0 +1,43 @@
+//===- TestShapeMappingAnalysis.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that prints the shape::ShapeMappingAnalysis already cached on the
+/// module, e.g. one left behind by the outline-shape-computation pass. It only
+/// queries the analysis cache; when nothing is cached it says so instead.
+struct TestShapeMappingPass
+    : public PassWrapper<TestShapeMappingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShapeMappingPass)
+
+  StringRef getArgument() const final { return "test-print-shape-mapping"; }
+  StringRef getDescription() const final {
+    return "Print the contents of a constructed shape mapping information.";
+  }
+  void runOnOperation() override {
+    // getCachedAnalysis never constructs the analysis, so the output reflects
+    // exactly what an earlier pass preserved for this module.
+    auto maybeAnalysis = getCachedAnalysis<shape::ShapeMappingAnalysis>();
+    if (!maybeAnalysis) {
+      // Terminate the line so subsequent stderr output is not glued onto it.
+      llvm::errs() << "No cached ShapeMappingAnalysis existed.\n";
+      return;
+    }
+    maybeAnalysis->get().print(llvm::errs());
+  }
+};
+
+} // namespace
+
+namespace mlir::test {
+/// Registers the `test-print-shape-mapping` pass (TestShapeMappingPass) so
+/// that mlir-opt's test-pass registration can expose it on the command line.
+void registerTestShapeMappingPass() {
+  // Constructing the registration object performs the registration; the
+  // temporary itself is not needed afterwards.
+  PassRegistration<TestShapeMappingPass>{};
+}
+} // namespace mlir::test

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 373a048d091ec..37d331b1000d5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -109,6 +109,7 @@ void registerTestPDLLPasses();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
 void registerTestTilingInterface();
@@ -208,6 +209,7 @@ void registerTestPasses() {
   mlir::test::registerTestPDLLPasses();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTilingInterface();


        


More information about the Mlir-commits mailing list