[Mlir-commits] [mlir] d40a19c - [mlir][linalg] Add pattern to push reshape after elementwise operation
llvmlistbot at llvm.org
Wed Apr 21 21:35:59 PDT 2021
Author: thomasraoux
Date: 2021-04-21T21:22:39-07:00
New Revision: d40a19c3a8b3da1a6be40f3b0b56075ed1e31e3a
URL: https://github.com/llvm/llvm-project/commit/d40a19c3a8b3da1a6be40f3b0b56075ed1e31e3a
DIFF: https://github.com/llvm/llvm-project/commit/d40a19c3a8b3da1a6be40f3b0b56075ed1e31e3a.diff
LOG: [mlir][linalg] Add pattern to push reshape after elementwise operation
This helps expose more fusion opportunities.
Differential Revision: https://reviews.llvm.org/D100685
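For downstream users, here is a minimal sketch of how the new entry point can
be combined with the existing elementwise fusion patterns. The helper function
below is illustrative only, not part of this commit; it mirrors the test pass
added below:

    // Sketch: push reshapes past elementwise ops first, then run the
    // existing elementwise fusion so the newly adjacent ops can fuse.
    void pushReshapesThenFuse(FuncOp funcOp) {
      MLIRContext *context = funcOp.getContext();
      RewritePatternSet pushPatterns(context);
      linalg::populatePushReshapeOpsPatterns(pushPatterns);
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
                                         std::move(pushPatterns));
      RewritePatternSet fusePatterns(context);
      linalg::populateElementwiseOpsFusionPatterns(fusePatterns);
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
                                         std::move(fusePatterns));
    }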
Added:
mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 884db2e939665..251a2f8e6d034 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -106,6 +106,10 @@ void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
+/// Patterns to push reshape ops towards the end of the graph in order to
+/// expose more fusion opportunities.
+void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
+
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4d6045a6d7b18..628b5969fe001 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -998,6 +998,161 @@ struct FoldProducerReshapeOpByLinearization
}
};
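+/// Compute reassociation indices from a list of reassociation affine maps,
+/// e.g. [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] becomes
+/// {{0, 1}, {2}}.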
+static SmallVector<ReassociationIndices>
+getReassociationIndices(ArrayRef<AffineMap> maps) {
+ SmallVector<ReassociationIndices> reassociation;
+ for (AffineMap map : maps) {
+ ReassociationIndices indices;
+ for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+ unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
+ indices.push_back(pos);
+ }
+ reassociation.push_back(indices);
+ }
+ return reassociation;
+}
+
+/// Pattern to move a rank-expanding reshape after an elementwise linalg
+/// generic op. This is useful to expose more fusion opportunities between
+/// named ops and generic ops. This can only be done if there is no broadcast
+/// or permutation within the dimensions we need to merge.
+///
+/// For example,
+///
+/// %0 = linalg.tensor_reshape %A [
+/// affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
+/// %2 = linalg.generic {indexing_maps = [
+/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+/// affine_map<(d0, d1, d2) -> (d2)>,
+/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
+/// ["parallel", "parallel", "parallel"]} {
+/// } -> tensor<112x112x16xf32>
+///
+/// into
+///
+/// %2 = linalg.generic {indexing_maps = [
+/// affine_map<(d0, d1) -> (d0, d1)>,
+/// affine_map<(d0, d1) -> (d1)>,
+/// affine_map<(d0, d1) -> (d0, d1)>],
+/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
+/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
+/// } -> tensor<12544x16xf32>
+/// %3 = linalg.tensor_reshape %2 [
+///  affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
+template <typename GenericOpTy>
+struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
+ using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOpTy op,
+ PatternRewriter &rewriter) const override {
+    // Only apply to elementwise linalg ops on tensors.
+ if (!op.hasTensorSemantics() ||
+ op.getNumParallelLoops() != op.getNumLoops())
+ return failure();
+    // Only support identity output maps. This could be extended to
+    // permutations if needed.
+ if (llvm::any_of(op.getOutputIndexingMaps(),
+ [](AffineMap map) { return !map.isIdentity(); }))
+ return failure();
+ int64_t destRank = op.getNumParallelLoops();
+ SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
+ TensorReshapeOp reshapeFound;
+    // 1. Look for tensor_reshape operands and save the dimensions being
+    // merged.
+ for (auto operand : llvm::enumerate(op.getInputs())) {
+ TensorReshapeOp reshapeOp =
+ operand.value().template getDefiningOp<TensorReshapeOp>();
+ if (!reshapeOp || reshapeOp.getSrcType().getRank() >
+ reshapeOp.getResultType().getRank()) {
+ continue;
+ }
+ // TODO: We could support non-identity map as long as the merged
+ // dimensions are still contiguous.
+ if (!op.getIndexingMaps()[operand.index()].isIdentity())
+ continue;
+ if (reshapeFound) {
+        // Only support a second reshape op if it has the same
+        // reassociation maps.
+ if (reshapeFound.getReassociationMaps() ==
+ reshapeOp.getReassociationMaps())
+ newOperands[operand.index()] = reshapeOp.src();
+ continue;
+ }
+ reshapeFound = reshapeOp;
+ newOperands[operand.index()] = reshapeOp.src();
+ }
+ if (!reshapeFound)
+ return failure();
+
+    // Calculate the reassociation indices and the reverse map from each
+    // original dimension to its reassociation group.
+ SmallVector<ReassociationIndices> reassociation =
+ getReassociationIndices(reshapeFound.getReassociationMaps());
+ SmallVector<unsigned, 4> remap(destRank);
+ for (auto &indices : llvm::enumerate(reassociation)) {
+ for (int64_t index : indices.value()) {
+ remap[index] = indices.index();
+ }
+ }
+    // 2. Verify that we can merge the dimensions in the linalg op and that
+    // we don't need to create new reshape operands. Inserting new reshape
+    // operands would defeat the purpose of the transformation.
+ for (auto operand : llvm::enumerate(op.getInputs())) {
+ if (operand.value() == newOperands[operand.index()]) {
+ AffineMap map = op.getIndexingMaps()[operand.index()];
+ for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
+ if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
+ return failure();
+ }
+ }
+ }
+
+ // 3. Calculate the affine map remapping and the reassociation to apply to
+ // output tensors.
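+    // E.g. with reassociation {{0, 1}, {2}}, the map
+    // (d0, d1, d2) -> (d0, d1, d2) becomes (d0, d1) -> (d0, d1) and
+    // (d0, d1, d2) -> (d2) becomes (d0, d1) -> (d1).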
+ SmallVector<AffineMap, 4> newMaps;
+ unsigned newRank = reassociation.size();
+ for (auto map : op.getIndexingMaps()) {
+ SmallVector<AffineExpr> newExprs;
+ for (auto expr : map.getResults()) {
+ unsigned position = expr.template cast<AffineDimExpr>().getPosition();
+        // Skip merged dimensions except for the last one of each group.
+ if (reassociation[remap[position]].back() == position) {
+ newExprs.push_back(
+ getAffineDimExpr(remap[position], op.getContext()));
+ }
+ }
+ newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
+ }
+
+ // 4. Reshape the output tensors.
+ SmallVector<Value> newOutputs;
+ SmallVector<Type> newOutputTypes;
+ for (auto output : op.outputs()) {
+ Value newOutput = rewriter.create<TensorReshapeOp>(
+ op->getLoc(), reshapeFound.getSrcType(), output, reassociation);
+ newOutputTypes.push_back(newOutput.getType());
+ newOutputs.push_back(newOutput);
+ }
+    // 5. Create a new generic op with lower rank.
+ SmallVector<StringRef, 4> iteratorTypes(newRank,
+ getParallelIteratorTypeName());
+ auto newOp =
+ rewriter.create<GenericOpTy>(op->getLoc(), newOutputTypes, newOperands,
+ newOutputs, newMaps, iteratorTypes);
+ rewriter.inlineRegionBefore(op.region(), newOp.region(),
+ newOp.region().begin());
+    // 6. Reshape the results so that the types match the uses.
+ SmallVector<Value> newResults;
+ for (auto result : llvm::enumerate(newOp->getResults())) {
+ newResults.push_back(rewriter.create<TensorReshapeOp>(
+ op->getLoc(), op.getOutputTensorTypes()[result.index()],
+ result.value(), reassociation));
+ }
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+};
+
/// Pattern to fuse a tensor_reshape op with its consumer
/// generic/indexed_generic op, when the reshape op is collapsing
/// dimensions. The dimensionality of the loop in the consumer is expanded.
@@ -1333,6 +1488,12 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
}
+void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
+ auto *context = patterns.getContext();
+ patterns.add<PushExpandingReshape<GenericOp>,
+ PushExpandingReshape<IndexedGenericOp>>(context);
+}
+
std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
return std::make_unique<FusionOfTensorOpsPass>();
}
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
new file mode 100644
index 0000000000000..bddf0d68749bc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func @reshape
+// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
+// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x112x16xf32> into tensor<?x16xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
+// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x16xf32> into tensor<?x112x16xf32>
+// CHECK: return %[[RR]] : tensor<?x112x16xf32>
+func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
+ %0 = linalg.tensor_reshape %A [
+ affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+ : tensor<?x16xf32> into tensor<?x112x16xf32>
+ %2 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0, %B : tensor<?x112x16xf32>, tensor<16xf32>)
+ outs(%init : tensor<?x112x16xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
+ %s = subf %arg1, %arg2 : f32
+ linalg.yield %s : f32
+ } -> tensor<?x112x16xf32>
+ return %2 : tensor<?x112x16xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func @reshape_multiple
+// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
+// CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] [#[[$MAP0]], #[[$MAP1]]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
+// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: return %[[RR]] : tensor<112x112x16xf32>
+func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
+ %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
+ %0 = linalg.tensor_reshape %A [
+ affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+ : tensor<12544x16xf32> into tensor<112x112x16xf32>
+ %1 = linalg.tensor_reshape %B [
+ affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+ : tensor<12544x16xf32> into tensor<112x112x16xf32>
+ %2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+ %3 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0, %1, %C : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
+ outs(%2 : tensor<112x112x16xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
+ %s = subf %arg1, %arg2 : f32
+ %m = mulf %s, %arg3 : f32
+ linalg.yield %m : f32
+ } -> tensor<112x112x16xf32>
+ return %3 : tensor<112x112x16xf32>
+}
+
+// -----
+
+// Negative test: since the second source is broadcast along d1, we cannot
+// merge the d0 and d1 dimensions.
+// CHECK-LABEL: func @reshape_negative
+// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: linalg.generic
+// CHECK: } -> tensor<112x112x16xf32>
+func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
+ %20 = linalg.tensor_reshape %A [
+ affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+ : tensor<12544x16xf32> into tensor<112x112x16xf32>
+ %21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
+ %22 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%20, %B : tensor<112x112x16xf32>, tensor<112xf32>)
+ outs(%21 : tensor<112x112x16xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
+ %s = subf %arg1, %arg2 : f32
+ linalg.yield %s : f32
+ } -> tensor<112x112x16xf32>
+ return %22 : tensor<112x112x16xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
index bfd7344708989..d0812ab8ec0dd 100644
--- a/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
@@ -66,6 +66,22 @@ struct TestLinalgElementwiseFusion
std::move(fusionPatterns));
}
};
+
+struct TestPushExpandingReshape
+ : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+
+ void runOnFunction() override {
+ MLIRContext *context = &this->getContext();
+ FuncOp funcOp = this->getFunction();
+ RewritePatternSet patterns(context);
+ linalg::populatePushReshapeOpsPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+ }
+};
} // namespace
namespace test {
@@ -74,6 +90,11 @@ void registerTestLinalgElementwiseFusion() {
"test-linalg-elementwise-fusion-patterns",
"Test Linalg element wise operation fusion patterns");
}
+
+void registerTestPushExpandingReshape() {
+ PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass(
+ "test-linalg-push-reshape", "Test Linalg reshape push patterns");
+}
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index eea5d8f494221..009e12eb9174b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@ void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgElementwiseFusion();
+void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
void registerTestLinalgTensorFusionTransforms();
void registerTestLinalgGreedyFusion();
@@ -156,6 +157,7 @@ void registerTestPasses() {
test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy();
test::registerTestLinalgElementwiseFusion();
+ test::registerTestPushExpandingReshape();
test::registerTestLinalgFusionTransforms();
test::registerTestLinalgTensorFusionTransforms();
test::registerTestLinalgGreedyFusion();