[Mlir-commits] [mlir] 071358e - [mlir][Linalg] Add producer-consumer fusion when producer is a ConstantOp
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 20 09:16:51 PDT 2020
Author: MaheshRavishankar
Date: 2020-05-20T09:16:19-07:00
New Revision: 071358e08224b9971f6b7fc49a5e014a5662187c
URL: https://github.com/llvm/llvm-project/commit/071358e08224b9971f6b7fc49a5e014a5662187c
DIFF: https://github.com/llvm/llvm-project/commit/071358e08224b9971f6b7fc49a5e014a5662187c.diff
LOG: [mlir][Linalg] Add producer-consumer fusion when producer is a ConstantOp
and Consumer is a GenericOp.
Differential Revision: https://reviews.llvm.org/D79838
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index fbeac83d4305..3123f95452fd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -766,6 +766,66 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
return fusedOp;
}
};
+
+/// Fuses a splat ConstantOp producer into a tensor-op consumer as a scalar.
+template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
+ static bool isFusible(ConstantOp producer, LinalgOpTy consumer,
+ unsigned consumerIdx) {
+ return producer.getResult().getType().isa<RankedTensorType>() &&
+ producer.value().template cast<DenseElementsAttr>().isSplat(); // NOTE(review): cast<> presumes the value is a DenseElementsAttr — confirm
+ }
+
+ static Operation *fuse(ConstantOp producer, LinalgOpTy consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
+ if (!isFusible(producer, consumer, consumerIdx))
+ return nullptr; // Producer is not a ranked-tensor splat constant.
+
+ // The fused op's indexing maps are the consumer's indexing maps with the
+ // entry at consumerIdx erased, because the operand that entry described is
+ // dropped from the fused op.
+ SmallVector<AffineMap, 4> fusedIndexMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+ fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
+
+ // The fused op's operands are the consumer's operands with the one at
+ // consumerIdx (the operand fed by the constant) erased.
+ SmallVector<Value, 4> fusedOperands(consumer.operand_begin(),
+ consumer.operand_end());
+ fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
+
+ // Materialize the splat's element value as a scalar ConstantOp.
+ Value scalarConstant = rewriter.create<ConstantOp>(
+ producer.getLoc(),
+ producer.value().template cast<DenseElementsAttr>().getSplatValue());
+
+ auto fusedOp = rewriter.create<LinalgOpTy>(
+ rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands, // NOTE(review): unknown loc discards the consumer's location — intended?
+ rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1),
+ rewriter.getI64IntegerAttr(consumer.getNumResults()),
+ rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr);
+
+ // Clone the consumer's region into the fused op, remapping the entry-block
+ // argument that stood for the dropped operand to the scalar constant.
+ Region &consumerRegion = consumer.region();
+ Block &entryBlock = *consumerRegion.begin();
+ unsigned argIndex =
+ entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx; // Offset past any block arguments in excess of the operand count.
+ BlockAndValueMapping mapping;
+ mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+ Region &fusedRegion = fusedOp.region();
+ rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
+ mapping);
+ return fusedOp;
+ }
+};
+
} // namespace
Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
@@ -789,6 +849,9 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
} else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
+ } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
+ return FuseConstantOpAsProducer<GenericOp>::fuse(
+ constantOpProducer, genericOp, consumerIdx, rewriter, folder);
}
return nullptr;
}
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 2c00f77edd3f..83bd1753eb28 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -219,3 +219,58 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
// CHECK: linalg.tensor_reshape
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)> // indexing map of the 1-D constant operand
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // identity map for %arg0 and the result
+func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+{
+ %0 = constant dense<42.0> : tensor<5xf32> // splat tensor constant used as operand 0 of the generic op
+ %1 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ %0, %arg0 {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %2 = mulf %arg1, %arg2 : f32
+ linalg.yield %2 : f32
+ }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @generic_op_constant_fusion
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 1 : i64
+// CHECK-SAME: args_out = 1 : i64
+// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
+// CHECK: mulf %[[CST]], %[[ARG1]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> ()> // zero-result map for the rank-0 constant operand
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // identity map for %arg0 and the result
+func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
+ -> tensor<5x?x?xf32>
+{
+ %0 = constant dense<42.0> : tensor<f32> // rank-0 tensor constant used as operand 0 of the generic op
+ %1 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ %0, %arg0 {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %2 = mulf %arg1, %arg2 : f32
+ linalg.yield %2 : f32
+ }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @generic_op_zero_dim_constant_fusion
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 1 : i64
+// CHECK-SAME: args_out = 1 : i64
+// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
+// CHECK: mulf %[[CST]], %[[ARG1]]
More information about the Mlir-commits
mailing list