[Mlir-commits] [mlir] [mlir][mesh] Add folding of ClusterShapeOp (PR #77033)

Boian Petkantchin llvmlistbot at llvm.org
Tue Jan 9 10:21:32 PST 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/77033

>From 3706d6fd8c26747883abb1d4c1081a4eb193fd20 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 3 Jan 2024 17:34:07 -0800
Subject: [PATCH 1/2] [mlir][mesh] Add folding of ClusterShapeOp

If the mesh has a static size on some of the requested axes,
the corresponding results are substituted with constants.
---
 .../Dialect/Mesh/Transforms/Simplifications.h |  7 ++
 .../Mesh/Transforms/Simplifications.cpp       | 86 +++++++++++++++++++
 mlir/test/Dialect/Mesh/folding.mlir           | 22 +++++
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |  3 +-
 mlir/test/lib/Dialect/Mesh/TestFolding.cpp    | 52 +++++++++++
 mlir/tools/mlir-opt/CMakeLists.txt            |  2 +-
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 7 files changed, 172 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/Mesh/folding.mlir
 create mode 100644 mlir/test/lib/Dialect/Mesh/TestFolding.cpp

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f70bdaa9de0a0f..f7096cfce634ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -19,6 +19,9 @@
 #include <utility>
 
 namespace mlir {
+
+class SymbolTable;
+
 namespace mesh {
 
 // If we have an algebraic op like "+" and a summing all-reduce,
@@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
 }
 
 void populateSimplificationPatterns(RewritePatternSet &patterns);
+// It is invalid to change ops that declare symbols during the application of
+// these patterns, because symbolTable is used to cache them.
+void populateFoldingPatterns(RewritePatternSet &patterns,
+                             SymbolTableCollection &symbolTable);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..eab3bc88fd1d38 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -8,6 +8,17 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <numeric>
+#include <utility>
 
 namespace mlir {
 namespace mesh {
@@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
   // TODO: add simplifications for all-gather and other collectives.
 }
 
+namespace {
+
+// This folding can not be done with an operation's fold method or
+// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
+// symbol tables.
+// We can't use DialectFoldInterface since the cache may be invalidated by some
+// pass changing the referenced ClusterOp ops.
+struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+  template <typename... OpRewritePatternArgs>
+  ClusterShapeFolder(SymbolTableCollection &symbolTable,
+                     OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern(
+            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
+        symbolTable(symbolTable) {}
+  LogicalResult matchAndRewrite(ClusterShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+        op.getOperation(), op.getMeshAttr());
+    if (!mesh) {
+      return failure();
+    }
+    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+    SmallVector<MeshAxis> opAxesIota;
+    if (opMeshAxes.empty()) {
+      opAxesIota.resize(mesh.getRank());
+      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
+      opMeshAxes = opAxesIota;
+    }
+    if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
+          return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+        })) {
+      // All requested mesh axes are dynamic. Nothing to fold.
+      return failure();
+    }
+
+    SmallVector<Value> newResults(op->getResults().size());
+    SmallVector<MeshAxis> newShapeOpMeshAxes;
+    SmallVector<size_t> newToOldResultsIndexMap;
+
+    for (size_t i = 0; i < opMeshAxes.size(); ++i) {
+      auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+      if (ShapedType::isDynamic(meshAxisSize)) {
+        newToOldResultsIndexMap.push_back(i);
+        newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+      } else {
+        // Fold static mesh axes.
+        newResults[i] = builder.create<arith::ConstantOp>(
+            builder.getIndexAttr(meshAxisSize));
+      }
+    }
+
+    // Query only the remaining dynamic axes. An empty axes list would mean
+    // "all axes", so create the op only if some axis is still dynamic.
+    if (!newShapeOpMeshAxes.empty()) {
+      ClusterShapeOp newShapeOp = builder.create<ClusterShapeOp>(
+          mesh.getSymName(), newShapeOpMeshAxes);
+      for (size_t i = 0; i < newShapeOp->getResults().size(); ++i)
+        newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+    }
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+
+private:
+  SymbolTableCollection &symbolTable;
+};
+
+} // namespace
+
+void populateFoldingPatterns(RewritePatternSet &patterns,
+                             SymbolTableCollection &symbolTable) {
+  patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
+}
+
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
new file mode 100644
index 00000000000000..1283353709ca3c
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
+mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding
+func.func @cluster_shape_op_folding() -> (index, index) {
+  // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
+  %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+  // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
+func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+  // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+  %0:2 = mesh.cluster_shape @mesh1 : index, index
+  // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..3da64694ee2155 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTestSimplifications
+add_mlir_library(MLIRMeshTest
+  TestFolding.cpp
   TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
new file mode 100644
index 00000000000000..1cf436edea8e35
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
@@ -0,0 +1,52 @@
+//===- TestSimplification.cpp - Test simplification -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+using namespace mlir;
+
+namespace {
+
+struct TestMeshFoldingPass
+    : public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-folding"; }
+  StringRef getDescription() const final { return "Test mesh folding."; }
+};
+} // namespace
+
+void TestMeshFoldingPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  SymbolTableCollection symbolTables;
+  mesh::populateFoldingPatterns(patterns, symbolTables);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    getOperation()->emitError()
+        << "Rewrite patter application did not converge.";
+    return signalPassFailure();
+  }
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index ce2f5bf4094a5a..9ad5b32c24f9de 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLoopLikeInterfaceTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
-    MLIRMeshTestSimplifications
+    MLIRMeshTest
     MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 09ff66e07957af..799fc58067d4fb 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestMeshFoldingPass();
 void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
 void registerTestNextAccessPass();
@@ -240,6 +241,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestMeshFoldingPass();
   mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();

>From aad27916c229931541680d6e4d4ed20c5c7c6464 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 5 Jan 2024 07:10:53 -0800
Subject: [PATCH 2/2] Include the folding patterns in populateSimplificationPatterns

---
 .../Dialect/Mesh/Transforms/Simplifications.h |  9 ++--
 .../Mesh/Transforms/Simplifications.cpp       | 21 +++++---
 mlir/test/Dialect/Mesh/folding.mlir           |  2 +-
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |  1 -
 mlir/test/lib/Dialect/Mesh/TestFolding.cpp    | 52 -------------------
 .../lib/Dialect/Mesh/TestSimplifications.cpp  |  8 ++-
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 -
 7 files changed, 25 insertions(+), 70 deletions(-)
 delete mode 100644 mlir/test/lib/Dialect/Mesh/TestFolding.cpp

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f7096cfce634ee..f438465251bb06 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -20,7 +20,7 @@
 
 namespace mlir {
 
-class SymbolTable;
+class SymbolTableCollection;
 
 namespace mesh {
 
@@ -105,11 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
       AlgebraicOp::getOperationName(), 1, patterns.getContext()));
 }
 
-void populateSimplificationPatterns(RewritePatternSet &patterns);
 // It is invalid to change ops that declare symbols during the application of
-// these patterns, because symbolTable is used to cache them.
+// these patterns, because symbolTableCollection is used to cache them.
+void populateSimplificationPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
 void populateFoldingPatterns(RewritePatternSet &patterns,
-                             SymbolTableCollection &symbolTable);
+                             SymbolTableCollection &symbolTableCollection);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index eab3bc88fd1d38..6262d3aa162654 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -23,7 +23,8 @@
 namespace mlir {
 namespace mesh {
 
-void populateSimplificationPatterns(RewritePatternSet &patterns) {
+void populateSimplificationPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
   populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
       patterns, Partial::Sum);
   populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
@@ -44,6 +45,8 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
       patterns, Partial::Max);
 
   // TODO: add simplifications for all-gather and other collectives.
+
+  populateFoldingPatterns(patterns, symbolTableCollection);
 }
 
 namespace {
@@ -55,16 +58,17 @@ namespace {
 // pass changing the referenced ClusterOp ops.
 struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
   template <typename... OpRewritePatternArgs>
-  ClusterShapeFolder(SymbolTableCollection &symbolTable,
+  ClusterShapeFolder(SymbolTableCollection &symbolTableCollection,
                      OpRewritePatternArgs &&...opRewritePatternArgs)
       : OpRewritePattern(
-            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
-        symbolTable(symbolTable) {}
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection) {}
   LogicalResult matchAndRewrite(ClusterShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-    ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
-        op.getOperation(), op.getMeshAttr());
+    ClusterOp mesh =
+        symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+            op.getOperation(), op.getMeshAttr());
     if (!mesh) {
       return failure();
     }
@@ -111,14 +115,15 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
   }
 
 private:
-  SymbolTableCollection &symbolTable;
+  SymbolTableCollection &symbolTableCollection;
 };
 
 } // namespace
 
 void populateFoldingPatterns(RewritePatternSet &patterns,
-                             SymbolTableCollection &symbolTable) {
-  patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
+                             SymbolTableCollection &symbolTableCollection) {
+  patterns.add<ClusterShapeFolder>(symbolTableCollection,
+                                   patterns.getContext());
 }
 
 } // namespace mesh
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index 1283353709ca3c..dd64d746341b83 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
 
 mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
 mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 3da64694ee2155..daff88235b5bde 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,6 +1,5 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMeshTest
-  TestFolding.cpp
   TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
deleted file mode 100644
index 1cf436edea8e35..00000000000000
--- a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
+++ /dev/null
@@ -1,52 +0,0 @@
-//===- TestSimplification.cpp - Test simplification -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <memory>
-
-using namespace mlir;
-
-namespace {
-
-struct TestMeshFoldingPass
-    : public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
-
-  void runOnOperation() override;
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<mesh::MeshDialect>();
-  }
-  StringRef getArgument() const final { return "test-mesh-folding"; }
-  StringRef getDescription() const final { return "Test mesh folding."; }
-};
-} // namespace
-
-void TestMeshFoldingPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  SymbolTableCollection symbolTables;
-  mesh::populateFoldingPatterns(patterns, symbolTables);
-  if (failed(
-          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
-    getOperation()->emitError()
-        << "Rewrite patter application did not converge.";
-    return signalPassFailure();
-  }
-}
-
-namespace mlir {
-namespace test {
-void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
index 93b1da52d46b4e..12a5fd532c4c96 100644
--- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -30,8 +31,14 @@ struct TestMeshSimplificationsPass
 
 void TestMeshSimplificationsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  mesh::populateSimplificationPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  SymbolTableCollection symbolTableCollection;
+  mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    getOperation()->emitError()
+        << "Rewrite pattern application did not converge.";
+    return signalPassFailure();
+  }
 }
 
 namespace mlir {
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 799fc58067d4fb..09ff66e07957af 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,7 +118,6 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
-void registerTestMeshFoldingPass();
 void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
 void registerTestNextAccessPass();
@@ -241,7 +240,6 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
-  mlir::test::registerTestMeshFoldingPass();
   mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();



More information about the Mlir-commits mailing list