[Mlir-commits] [mlir] 01055ed - [mlir][linalg] Move linalg.fill folding into linalg.generic pattern from canonicalization to elementwise fusion
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 5 13:13:10 PDT 2022
Author: Nirvedh
Date: 2022-04-05T20:13:03Z
New Revision: 01055ed1d72dd74d0dcdf29d2cb734704ad673cb
URL: https://github.com/llvm/llvm-project/commit/01055ed1d72dd74d0dcdf29d2cb734704ad673cb
DIFF: https://github.com/llvm/llvm-project/commit/01055ed1d72dd74d0dcdf29d2cb734704ad673cb.diff
LOG: [mlir][linalg] Move linalg.fill folding into linalg.generic pattern from canonicalization to elementwise fusion
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D122847
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c72b52b4e1f08..a701cb6016dd3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -913,35 +913,12 @@ struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
return success();
}
};
-
-/// Fold linalg.fill into linalg.generic
-struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- if (!genericOp.hasTensorSemantics())
- return failure();
- bool fillFound = false;
- Block &payload = genericOp.region().front();
- for (OpOperand *opOperand : genericOp.getInputOperands()) {
- FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
- if (fillOp) {
- fillFound = true;
- payload.getArgument(opOperand->getOperandNumber())
- .replaceAllUsesWith(fillOp.value());
- }
- }
- // fail if there are no FillOps to fold.
- return success(fillFound);
- }
-};
} // namespace
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
- DeadArgsGenericOpInputs, FoldFillWithGenericOp>(context);
+ DeadArgsGenericOpInputs>(context);
}
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4effa26cb75cf..a8626bbc5b0fb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2215,8 +2215,31 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
return success();
}
};
-} // namespace
+/// Fold linalg.fill into linalg.generic
+struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+ bool fillFound = false;
+ Block &payload = genericOp.region().front();
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ if (!genericOp.payloadUsesValueFromOperand(opOperand))
+ continue;
+ FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
+ if (!fillOp)
+ continue;
+ fillFound = true;
+ payload.getArgument(opOperand->getOperandNumber())
+ .replaceAllUsesWith(fillOp.value());
+ }
+ return success(fillFound);
+ }
+};
+} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
@@ -2261,7 +2284,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
FoldConstantTranspose>(context,
options.controlElementwiseOpsFusionFn);
- patterns.add<RemoveOutsDependency>(context);
+ patterns.add<RemoveOutsDependency, FoldFillWithGenericOp>(context);
populateSparseTensorRewriting(patterns);
populateFoldReshapeOpsByExpansionPatterns(patterns,
options.controlFoldingReshapesFn);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 56ce26778c971..ecc3bc5b696ba 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -343,59 +343,6 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
// -----
-// CHECK-LABEL: func @fold_fill_generic_basic
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
-#map0 = affine_map<(d0) -> (d0)>
-func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 7.0 : f32
- %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
- %1 = linalg.init_tensor [%0] : tensor<?xf32>
- %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
- %3 = linalg.init_tensor [%0] : tensor<?xf32>
- %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
- %5 = arith.addf %arg1, %arg2 : f32
- linalg.yield %5 : f32
- } -> tensor<?xf32>
- return %4 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @fold_fill_generic_mixedaccess
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-NOT: ins
-// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d1, d0)>
-func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 0 : index
- %cst1 = arith.constant 7.0 : f32
- %cst2 = arith.constant 6.0 : f32
- %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
- %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
- %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
- %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
- ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
- %8 = arith.divf %arg1, %arg2 : f32
- linalg.yield %8 : f32
- } -> tensor<?x?xf32>
- return %7 : tensor<?x?xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @remove_deadargs_generic_basic
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 0572ecd98ffc7..868b6e5f3a7d6 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -975,3 +975,56 @@ func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tens
// CHECK: %[[PRODUCER:.+]] = linalg.generic
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[PRODUCER]]
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 7.0 : f32
+ %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %1 = linalg.init_tensor [%0] : tensor<?xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+ %3 = linalg.init_tensor [%0] : tensor<?xf32>
+ %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %5 = arith.addf %arg1, %arg2 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?xf32>
+ return %4 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_mixedaccess
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 0 : index
+ %cst1 = arith.constant 7.0 : f32
+ %cst2 = arith.constant 6.0 : f32
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+ %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %8 = arith.divf %arg1, %arg2 : f32
+ linalg.yield %8 : f32
+ } -> tensor<?x?xf32>
+ return %7 : tensor<?x?xf32>
+}
More information about the Mlir-commits mailing list