[Mlir-commits] [mlir] 01055ed - [mlir][linalg] Move linalg.fill folding into linalg.generic pattern from canonicalization to elementwise fusion

llvmlistbot at llvm.org
Tue Apr 5 13:13:10 PDT 2022


Author: Nirvedh
Date: 2022-04-05T20:13:03Z
New Revision: 01055ed1d72dd74d0dcdf29d2cb734704ad673cb

URL: https://github.com/llvm/llvm-project/commit/01055ed1d72dd74d0dcdf29d2cb734704ad673cb
DIFF: https://github.com/llvm/llvm-project/commit/01055ed1d72dd74d0dcdf29d2cb734704ad673cb.diff

LOG: [mlir][linalg] Move linalg.fill folding into linalg.generic pattern from canonicalization to elementwise fusion
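
The moved FoldFillWithGenericOp pattern replaces uses of a linalg.generic
block argument with the constant fill value whenever the corresponding input
operand is produced by a linalg.fill. A minimal before/after sketch, adapted
from the fold_fill_generic_basic test in this patch (SSA names are
illustrative; the now-dead fill operand is removed separately, e.g. by the
DeadArgsGenericOpInputs canonicalization):

  // Before: input %2 is produced by a linalg.fill of %cst.
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel"]}
      ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs(%3 : tensor<?xf32>) {
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %5 = arith.addf %arg1, %arg2 : f32
    linalg.yield %5 : f32
  } -> tensor<?xf32>

  // After folding plus dead-argument cleanup: the payload uses %cst directly,
  // so the fill and its init_tensor become dead.
  %4 = linalg.generic {indexing_maps = [#map0, #map0],
                       iterator_types = ["parallel"]}
      ins(%arg0 : tensor<?xf32>) outs(%3 : tensor<?xf32>) {
  ^bb0(%arg1: f32, %arg3: f32):
    %5 = arith.addf %arg1, %cst : f32
    linalg.yield %5 : f32
  } -> tensor<?xf32>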

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D122847

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c72b52b4e1f08..a701cb6016dd3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -913,35 +913,12 @@ struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
-
-/// Fold linalg.fill into linalg.generic
-struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-    bool fillFound = false;
-    Block &payload = genericOp.region().front();
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
-      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
-      if (fillOp) {
-        fillFound = true;
-        payload.getArgument(opOperand->getOperandNumber())
-            .replaceAllUsesWith(fillOp.value());
-      }
-    }
-    // fail if there are no FillOps to fold.
-    return success(fillFound);
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
-              DeadArgsGenericOpInputs, FoldFillWithGenericOp>(context);
+              DeadArgsGenericOpInputs>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4effa26cb75cf..a8626bbc5b0fb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2215,8 +2215,31 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
-} // namespace
 
+/// Fold linalg.fill into linalg.generic
+struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    bool fillFound = false;
+    Block &payload = genericOp.region().front();
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      if (!genericOp.payloadUsesValueFromOperand(opOperand))
+        continue;
+      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
+      if (!fillOp)
+        continue;
+      fillFound = true;
+      payload.getArgument(opOperand->getOperandNumber())
+          .replaceAllUsesWith(fillOp.value());
+    }
+    return success(fillFound);
+  }
+};
+} // namespace
 //===---------------------------------------------------------------------===//
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
@@ -2261,7 +2284,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
                FoldConstantTranspose>(context,
                                       options.controlElementwiseOpsFusionFn);
-  patterns.add<RemoveOutsDependency>(context);
+  patterns.add<RemoveOutsDependency, FoldFillWithGenericOp>(context);
   populateSparseTensorRewriting(patterns);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 56ce26778c971..ecc3bc5b696ba 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -343,59 +343,6 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @fold_fill_generic_basic
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> { 
-//   CHECK-NOT: linalg.fill
-//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
-//  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
-#map0 = affine_map<(d0) -> (d0)>
-func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
-  %1 = linalg.init_tensor [%0] : tensor<?xf32>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
-  %3 = linalg.init_tensor [%0] : tensor<?xf32>
-  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %5 = arith.addf  %arg1, %arg2 : f32
-	linalg.yield %5 : f32
-  } -> tensor<?xf32>
-  return %4 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @fold_fill_generic_mixedaccess
-//   CHECK-NOT: linalg.fill
-//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-//   CHECK-NOT: ins
-//  CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d1, d0)>
-func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 0 : index
-  %cst1 = arith.constant 7.0 : f32
-  %cst2 = arith.constant 6.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
-  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
-  %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
-  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %8 = arith.divf  %arg1, %arg2 : f32
-	linalg.yield %8 : f32
-  } -> tensor<?x?xf32>
-  return %7 : tensor<?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @remove_deadargs_generic_basic
 //  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> { 
 //       CHECK: %[[GENERIC_OP:.*]] = linalg.generic

diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 0572ecd98ffc7..868b6e5f3a7d6 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -975,3 +975,56 @@ func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tens
 //       CHECK:   %[[PRODUCER:.+]] = linalg.generic
 //       CHECK:    linalg.generic
 //   CHECK-SAME:       ins(%[[PRODUCER]]
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_basic
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> { 
+//   CHECK-NOT: linalg.fill
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+//  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+  %3 = linalg.init_tensor [%0] : tensor<?xf32>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %5 = arith.addf  %arg1, %arg2 : f32
+	linalg.yield %5 : f32
+  } -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_mixedaccess
+//   CHECK-NOT: linalg.fill
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//   CHECK-NOT: ins
+//  CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 0 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %8 = arith.divf  %arg1, %arg2 : f32
+	linalg.yield %8 : f32
+  } -> tensor<?x?xf32>
+  return %7 : tensor<?x?xf32>
+}


        

