[Mlir-commits] [mlir] 7568f71 - Revert "[mlir][Linalg] NFC: Combine elementwise fusion test passes."

Mahesh Ravishankar llvmlistbot at llvm.org
Mon Feb 7 14:51:36 PST 2022


Author: Mahesh Ravishankar
Date: 2022-02-07T22:51:29Z
New Revision: 7568f7101f8881d7a8ea830cf7c2b9d8cfce4ac9

URL: https://github.com/llvm/llvm-project/commit/7568f7101f8881d7a8ea830cf7c2b9d8cfce4ac9
DIFF: https://github.com/llvm/llvm-project/commit/7568f7101f8881d7a8ea830cf7c2b9d8cfce4ac9.diff

LOG: Revert "[mlir][Linalg] NFC: Combine elementwise fusion test passes."

This reverts commit d730336411b59622a625510378cec0f9d23807c6.

Added:
    (none)

Modified: 
    mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
    mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
    mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
index 103a04d79ba4a..d81aab66491a4 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #binary2Dpointwise = {

diff  --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index a1d428865120f..0c02ff8c54d1f 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
 
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

diff  --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
index c4e7d5552678e..d9e440c96efd8 100644
--- a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
 
 func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 85df3252fb56f..30bef4af8bcc3 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -58,77 +58,87 @@ struct TestLinalgElementwiseFusion
     return "Test Linalg element wise operation fusion patterns";
   }
 
-  Option<bool>
-      fuseGenericOps(*this, "fuse-generic-ops",
-                     llvm::cl::desc("Test fusion of generic operations."),
-                     llvm::cl::init(false));
-
-  Option<bool> controlFuseByExpansion(
-      *this, "control-fusion-by-expansion",
-      llvm::cl::desc(
-          "Test controlling fusion of reshape with generic op by expansion"),
-      llvm::cl::init(false));
-
-  Option<bool>
-      pushExpandingReshape(*this, "push-expanding-reshape",
-                           llvm::cl::desc("Test linalg expand_shape -> generic "
-                                          "to generic -> expand_shape pattern"),
-                           llvm::cl::init(false));
-
   void runOnOperation() override {
     MLIRContext *context = &this->getContext();
     FuncOp funcOp = this->getOperation();
+    RewritePatternSet fusionPatterns(context);
+
+    linalg::populateElementwiseOpsFusionPatterns(
+        fusionPatterns,
+        linalg::LinalgElementwiseFusionOptions()
+            .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
 
-    if (fuseGenericOps) {
-      RewritePatternSet fusionPatterns(context);
-      linalg::populateElementwiseOpsFusionPatterns(
-          fusionPatterns,
-          linalg::LinalgElementwiseFusionOptions()
-              .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
-
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
-      return;
-    }
-
-    if (controlFuseByExpansion) {
-      RewritePatternSet fusionPatterns(context);
-
-      linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
-          [](const OpResult &producer, OpOperand &consumer) {
-            if (auto collapseOp =
-                    producer.getDefiningOp<tensor::CollapseShapeOp>()) {
-              if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
-                return false;
-              }
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                       std::move(fusionPatterns));
+  }
+};
+
+struct TestLinalgControlFuseByExpansion
+    : public PassWrapper<TestLinalgControlFuseByExpansion,
+                         OperationPass<FuncOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-control-fusion-by-expansion";
+  }
+  StringRef getDescription() const final {
+    return "Test controlling of fusion of elementwise ops with reshape by "
+           "expansion";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    FuncOp funcOp = this->getOperation();
+    RewritePatternSet fusionPatterns(context);
+
+    linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
+        [](const OpResult &producer, OpOperand &consumer) {
+          if (auto collapseOp =
+                  producer.getDefiningOp<tensor::CollapseShapeOp>()) {
+            if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
+              return false;
             }
-            if (auto expandOp =
-                    dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
-              if (expandOp->hasOneUse()) {
-                OpOperand &use = *expandOp->getUses().begin();
-                auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
-                if (linalgOp && linalgOp.isOutputTensor(&use))
-                  return true;
-              }
+          }
+          if (auto expandOp =
+                  dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+            if (expandOp->hasOneUse()) {
+              OpOperand &use = *expandOp->getUses().begin();
+              auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
+              if (linalgOp && linalgOp.isOutputTensor(&use))
+                return true;
             }
-            return linalg::skipUnitDimReshape(producer, consumer);
-          };
-
-      linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
-                                                        controlReshapeFusionFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
-      return;
-    }
-
-    if (pushExpandingReshape) {
-      RewritePatternSet patterns(context);
-      linalg::populatePushReshapeOpsPatterns(patterns);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
-    }
+          }
+          return linalg::skipUnitDimReshape(producer, consumer);
+        };
+
+    linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
+                                                      controlReshapeFusionFn);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                       std::move(fusionPatterns));
   }
 };
 
+struct TestPushExpandingReshape
+    : public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final { return "test-linalg-push-reshape"; }
+  StringRef getDescription() const final {
+    return "Test Linalg reshape push patterns";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    FuncOp funcOp = this->getOperation();
+    RewritePatternSet patterns(context);
+    linalg::populatePushReshapeOpsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace test {

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 73d1b54bbf4fd..5b09cb8671eb1 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,8 +81,10 @@ void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgControlFuseByExpansion();
 void registerTestLinalgDistribution();
 void registerTestLinalgElementwiseFusion();
+void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
 void registerTestLinalgTiledLoopFusionTransforms();
@@ -170,8 +172,10 @@ void registerTestPasses() {
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
+  mlir::test::registerTestLinalgControlFuseByExpansion();
   mlir::test::registerTestLinalgDistribution();
   mlir::test::registerTestLinalgElementwiseFusion();
+  mlir::test::registerTestPushExpandingReshape();
   mlir::test::registerTestLinalgFusionTransforms();
   mlir::test::registerTestLinalgTensorFusionTransforms();
   mlir::test::registerTestLinalgTiledLoopFusionTransforms();


        


More information about the Mlir-commits mailing list