[Mlir-commits] [mlir] d40a19c - [mlir][linalg] Add pattern to push reshape after elementwise operation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 21 21:35:59 PDT 2021

Author: thomasraoux
Date: 2021-04-21T21:22:39-07:00
New Revision: d40a19c3a8b3da1a6be40f3b0b56075ed1e31e3a

URL: https://github.com/llvm/llvm-project/commit/d40a19c3a8b3da1a6be40f3b0b56075ed1e31e3a
DIFF: https://github.com/llvm/llvm-project/commit/d40a19c3a8b3da1a6be40f3b0b56075ed1e31e3a.diff

LOG: [mlir][linalg] Add pattern to push reshape after elementwise operation

This helps expose more fusion opportunities.

Differential Revision: https://reviews.llvm.org/D100685




diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 884db2e939665..251a2f8e6d034 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -106,6 +106,10 @@ void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
     LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
+/// Populates `patterns` with patterns that push expanding tensor_reshape ops
+/// feeding elementwise linalg generic ops past those ops (i.e. towards the end
+/// of the graph), in order to expose more fusion opportunities.
+void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4d6045a6d7b18..628b5969fe001 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -998,6 +998,161 @@ struct FoldProducerReshapeOpByLinearization
+/// Turns each reassociation map into the ordered list of loop-dimension
+/// positions named by its results. Every result is expected to be a plain
+/// AffineDimExpr; the `cast` below enforces that.
+static SmallVector<ReassociationIndices>
+getReassociationIndices(ArrayRef<AffineMap> maps) {
+  SmallVector<ReassociationIndices> groups;
+  for (AffineMap map : maps) {
+    ReassociationIndices group;
+    for (AffineExpr expr : map.getResults())
+      group.push_back(expr.cast<AffineDimExpr>().getPosition());
+    groups.push_back(group);
+  }
+  return groups;
+/// Pattern to move an expanding tensor_reshape that feeds an elementwise linalg
+/// generic op to after that op, so the op itself runs on the lower-rank,
+/// un-expanded tensors. This is useful to expose more fusion opportunities
+/// between named ops and generic ops. This can only be done if there is no
+/// broadcast or permutation within the dimensions we need to merge: every
+/// other input must either come from a reshape with the same reassociation or
+/// only access dimensions that are not merged. Each output keeps its own
+/// element type. indexed_generic ops are not handled, because collapsing loops
+/// changes the loop indices their region reads.
+/// For example,
+///  %0 = linalg.tensor_reshape %A [
+///    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+///  %2 = linalg.generic {indexing_maps = [
+///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///    affine_map<(d0, d1, d2) -> (d2)>,
+///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
+///    ["parallel", "parallel", "parallel"]} ins(%0, %B
+///    : tensor<112x112x16xf32>, tensor<16xf32>) outs(%init
+///    : tensor<112x112x16xf32>) {
+///  } -> tensor<112x112x16xf32>
+///  into
+///  %1 = linalg.tensor_reshape %init [
+///    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+///    : tensor<112x112x16xf32> into tensor<12544x16xf32>
+///  %2 = linalg.generic {indexing_maps = [
+///    affine_map<(d0, d1) -> (d0, d1)>,
+///    affine_map<(d0, d1) -> (d1)>,
+///    affine_map<(d0, d1) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel"]} ins(%A, %B
+///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
+///  } -> tensor<12544x16xf32>
+///  %3 = linalg.tensor_reshape %2 [
+///    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
+template <typename GenericOpTy>
+struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
+  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Only apply to elementwise linalg on tensor.
+    if (!op.hasTensorSemantics() ||
+        op.getNumParallelLoops() != op.getNumLoops())
+      return failure();
+    // An indexed_generic region takes one index block argument per loop.
+    // Inlining it into an op with fewer loops would produce invalid IR (and
+    // the index values would no longer match the original iteration space).
+    if (isa<IndexedGenericOp>(op.getOperation()))
+      return failure();
+    // Only support identity output maps. It could be extended to permutations
+    // if needed.
+    if (llvm::any_of(op.getOutputIndexingMaps(),
+                     [](AffineMap map) { return !map.isIdentity(); }))
+      return failure();
+    int64_t destRank = op.getNumParallelLoops();
+    SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
+    TensorReshapeOp reshapeFound;
+    // 1. Look for expanding tensor_reshape operands and record the dimensions
+    // they merge.
+    for (auto operand : llvm::enumerate(op.getInputs())) {
+      TensorReshapeOp reshapeOp =
+          operand.value().template getDefiningOp<TensorReshapeOp>();
+      // Skip non-reshapes and collapsing reshapes (source rank > result rank).
+      if (!reshapeOp || reshapeOp.getSrcType().getRank() >
+                            reshapeOp.getResultType().getRank()) {
+        continue;
+      }
+      // TODO: We could support non-identity map as long as the merged
+      // dimensions are still contiguous.
+      if (!op.getIndexingMaps()[operand.index()].isIdentity())
+        continue;
+      if (reshapeFound) {
+        // Only support a second reshape op if it has the same reassociation
+        // maps.
+        if (reshapeFound.getReassociationMaps() ==
+            reshapeOp.getReassociationMaps())
+          newOperands[operand.index()] = reshapeOp.src();
+        continue;
+      }
+      reshapeFound = reshapeOp;
+      newOperands[operand.index()] = reshapeOp.src();
+    }
+    if (!reshapeFound)
+      return failure();
+    // Calculate the reassociation indices and the associated reverse map from
+    // each original loop dimension to the index of its reassociation group.
+    SmallVector<ReassociationIndices> reassociation =
+        getReassociationIndices(reshapeFound.getReassociationMaps());
+    SmallVector<unsigned, 4> remap(destRank);
+    for (auto indices : llvm::enumerate(reassociation)) {
+      for (int64_t index : indices.value()) {
+        remap[index] = indices.index();
+      }
+    }
+    // 2. Verify that we can merge the dimensions in the linalg op and that we
+    // don't need to create new reshape operands. Inserting new reshape
+    // operands would defeat the purpose of the transformation.
+    for (auto operand : llvm::enumerate(op.getInputs())) {
+      if (operand.value() == newOperands[operand.index()]) {
+        AffineMap map = op.getIndexingMaps()[operand.index()];
+        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
+          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
+            return failure();
+        }
+      }
+    }
+    // 3. Calculate the affine map remapping and the reassociation to apply to
+    // output tensors.
+    SmallVector<AffineMap, 4> newMaps;
+    unsigned newRank = reassociation.size();
+    for (auto map : op.getIndexingMaps()) {
+      SmallVector<AffineExpr> newExprs;
+      for (auto expr : map.getResults()) {
+        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
+        // Skip merged dimensions except for the last one of each group.
+        if (reassociation[remap[position]].back() == position) {
+          newExprs.push_back(
+              getAffineDimExpr(remap[position], op.getContext()));
+        }
+      }
+      newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
+    }
+    // 4. Reshape the output tensors. The collapsed shape comes from the
+    // reshape source, but each output keeps its own element type, which may
+    // differ from the reshaped input's (e.g. an elementwise comparison).
+    SmallVector<Value> newOutputs;
+    SmallVector<Type> newOutputTypes;
+    for (auto output : op.outputs()) {
+      auto newOutputType = RankedTensorType::get(
+          reshapeFound.getSrcType().getShape(),
+          output.getType().template cast<RankedTensorType>().getElementType());
+      Value newOutput = rewriter.create<TensorReshapeOp>(
+          op->getLoc(), newOutputType, output, reassociation);
+      newOutputTypes.push_back(newOutputType);
+      newOutputs.push_back(newOutput);
+    }
+    // 5. Create a new generic op with lower rank.
+    SmallVector<StringRef, 4> iteratorTypes(newRank,
+                                            getParallelIteratorTypeName());
+    auto newOp =
+        rewriter.create<GenericOpTy>(op->getLoc(), newOutputTypes, newOperands,
+                                     newOutputs, newMaps, iteratorTypes);
+    rewriter.inlineRegionBefore(op.region(), newOp.region(),
+                                newOp.region().begin());
+    // 6. Reshape the results so that their types match the original uses.
+    SmallVector<Value> newResults;
+    for (auto result : llvm::enumerate(newOp->getResults())) {
+      newResults.push_back(rewriter.create<TensorReshapeOp>(
+          op->getLoc(), op.getOutputTensorTypes()[result.index()],
+          result.value(), reassociation));
+    }
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
 /// Pattern to fuse a tensor_reshape op with its consumer
 /// generic/indexed_generic op, when the reshape op is collapsing
 /// dimensions. The dimensionality of the loop in the consumer is expanded.
@@ -1333,6 +1488,12 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+/// Registers the pattern that pushes expanding tensor_reshape ops past
+/// elementwise linalg.generic ops. Only GenericOp is registered: for
+/// IndexedGenericOp the region's per-loop index arguments would no longer
+/// match the collapsed loop nest, so the rewrite would produce invalid IR.
+void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<PushExpandingReshape<GenericOp>>(context);
 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
   return std::make_unique<FusionOfTensorOpsPass>();

diff  --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
new file mode 100644
index 0000000000000..bddf0d68749bc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
+// Single expanding reshape on the first input. %B only reads d2, which is not
+// merged, so the reshape is pushed past the generic: %init is collapsed, the
+// generic runs over 2 loops, and its result is expanded back to 3-D.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-LABEL: func @reshape
+// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
+//      CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x112x16xf32> into tensor<?x16xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
+//      CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x16xf32> into tensor<?x112x16xf32>
+//      CHECK: return %[[RR]] : tensor<?x112x16xf32>
+func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
+  %0 = linalg.tensor_reshape %A [
+    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<?x16xf32> into tensor<?x112x16xf32>
+  %2 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  ins(%0, %B : tensor<?x112x16xf32>, tensor<16xf32>)
+  outs(%init : tensor<?x112x16xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+    %s = subf %arg1, %arg2 : f32
+    linalg.yield %s : f32
+  } -> tensor<?x112x16xf32>
+  return %2 : tensor<?x112x16xf32>
+// -----
+// Two inputs expanded with the same reassociation are both replaced by their
+// sources; only the init tensor needs a new collapsing reshape.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-LABEL: func @reshape_multiple
+// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
+//      CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+//      CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] [#[[$MAP0]], #[[$MAP1]]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
+//      CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
+//      CHECK: return %[[RR]] : tensor<112x112x16xf32>
+func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
+  %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
+  %0 = linalg.tensor_reshape %A [
+    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+  %1 = linalg.tensor_reshape %B [
+    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+  %2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+  %3 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  ins(%0, %1, %C : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
+  outs(%2 : tensor<112x112x16xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+    %s = subf %arg1, %arg2 : f32
+    %m = mulf %s, %arg3 : f32
+    linalg.yield %m : f32
+  } -> tensor<112x112x16xf32>
+  return %3 : tensor<112x112x16xf32>
+// -----
+// Negative test: the second input %B is indexed by d1 alone (broadcast along
+// d0 and d2), so d0 and d1 cannot be merged and the reshape stays in place.
+// CHECK-LABEL: func @reshape_negative
+// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: linalg.generic
+// CHECK: } -> tensor<112x112x16xf32>
+func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
+  %20 = linalg.tensor_reshape %A [
+    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<12544x16xf32> into tensor<112x112x16xf32>
+  %21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+  %22 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  ins(%20, %B : tensor<112x112x16xf32>, tensor<112xf32>)
+  outs(%21 : tensor<112x112x16xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
+    %s = subf %arg1, %arg2 : f32
+    linalg.yield %s : f32
+  } -> tensor<112x112x16xf32>
+  return %22 : tensor<112x112x16xf32>

diff  --git a/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
index bfd7344708989..d0812ab8ec0dd 100644
--- a/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
@@ -66,6 +66,22 @@ struct TestLinalgElementwiseFusion
+/// Test pass that greedily applies the linalg push-reshape patterns to every
+/// function it runs on.
+struct TestPushExpandingReshape
+    : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
+  /// Dialects whose ops the rewrite patterns may create.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet pushPatterns(&getContext());
+    linalg::populatePushReshapeOpsPatterns(pushPatterns);
+    // Convergence failure is not an error for this test pass.
+    (void)applyPatternsAndFoldGreedily(getFunction().getBody(),
+                                       std::move(pushPatterns));
+  }
 } // namespace
 namespace test {
@@ -74,6 +90,11 @@ void registerTestLinalgElementwiseFusion() {
       "Test Linalg element wise operation fusion patterns");
+/// Registers TestPushExpandingReshape as `-test-linalg-push-reshape`.
+void registerTestPushExpandingReshape() {
+  PassRegistration<TestPushExpandingReshape> registration(
+      "test-linalg-push-reshape", "Test Linalg reshape push patterns");
 } // namespace test
 } // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index eea5d8f494221..009e12eb9174b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@ void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgElementwiseFusion();
+void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
 void registerTestLinalgGreedyFusion();
@@ -156,6 +157,7 @@ void registerTestPasses() {
+  test::registerTestPushExpandingReshape();


More information about the Mlir-commits mailing list