[Mlir-commits] [mlir] e3f75c1 - [mlir][linalg] Allow some fusion on mixed generics
Ivan Butygin
llvmlistbot at llvm.org
Tue Nov 29 06:36:10 PST 2022
Author: Ivan Butygin
Date: 2022-11-29T15:35:02+01:00
New Revision: e3f75c1cb78123259b027ce5c82533ff7013c1b3
URL: https://github.com/llvm/llvm-project/commit/e3f75c1cb78123259b027ce5c82533ff7013c1b3
DIFF: https://github.com/llvm/llvm-project/commit/e3f75c1cb78123259b027ce5c82533ff7013c1b3.diff
LOG: [mlir][linalg] Allow some fusion on mixed generics
Relax the linalg elementwise fusion check to allow consumers with mixed tensor/buffer semantics. The producer is still required to have full tensor semantics to avoid potential memref aliasing.
Differential Revision: https://reviews.llvm.org/D138759
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index d0efba5a98938..a127bd2dc652b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -55,7 +55,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (genericOp.hasBufferSemantics())
+ // Mixed and buffer sematics aren't supported.
+ if (!genericOp.hasTensorSemantics())
return failure();
// Only support ops generating one output for now.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 26636bc053653..f391b2cc01c91 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -79,8 +79,11 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!producer || !consumer)
return false;
- // Producer and consumer must have tensor semantics.
- if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
+ // Consumer can have mixed semantics, just check operand itself has tensor
+ // type. Producer must have full tensor semantics to avoid potential
+ // aliasing between producer and consumer memrefs.
+ if (!producer.hasTensorSemantics() ||
+ !fusedOperand->get().getType().isa<RankedTensorType>())
return false;
// Verify that
@@ -348,7 +351,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
for (OpOperand *opOperand : consumer.getDpsInitOperands()) {
fusedOutputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
- fusedResultTypes.push_back(opOperand->get().getType());
+ Type resultType = opOperand->get().getType();
+ if (!resultType.isa<MemRefType>())
+ fusedResultTypes.push_back(resultType);
}
// Generate the fused op.
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index b334eedf61ef6..aff6a8fc1925e 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -54,15 +54,6 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
<< ") to be equal to the number of output tensors ("
<< outputTensorOperands.size() << ")";
- // Simplifying assumption: either full tensor or full buffer mode.
- // This allows simpler verification of output operands vs result types
- // without premature tracking of which operand is what in mixed-mode.
- // TODO: relax when mixed-mode needs to pass verification.
- if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
- return op->emitOpError(
- "expected output operands to all have tensor type or "
- "all have buffer type");
-
for (OpOperand *opOperand : outputTensorOperands) {
OpResult result = dstStyleOp.getTiedOpResult(opOperand);
if (result.getType() != opOperand->get().getType())
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 2950b27b290b7..5bcae368b83b0 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1110,3 +1110,43 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]
+
+// -----
+
+// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @mixed_fusion
+func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
+{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %4 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ // CHECK: linalg.generic {
+ // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
+ linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg8 : memref<?x?xf32>) {
+ // CHECK: ^{{[a-zA-Z0-9_]*}}
+ // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
+ // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
+ // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+ // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
+ // CHECK-NOT: linalg.yield
+ // CHECK: arith.mulf [[T1]], [[ARG2]]
+ // CHECK: linalg.yield
+ %5 = arith.mulf %arg5, %arg6 : f32
+ linalg.yield %5 : f32
+ }
+ return
+}
More information about the Mlir-commits
mailing list