[Mlir-commits] [mlir] [mlir][mesh] Add folding of ClusterShapeOp (PR #77033)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 4 17:43:50 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
If the mesh has a static size along some of the requested axes, the corresponding results are replaced with constants.
---
Full diff: https://github.com/llvm/llvm-project/pull/77033.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+7)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+86)
- (added) mlir/test/Dialect/Mesh/folding.mlir (+22)
- (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+2-1)
- (added) mlir/test/lib/Dialect/Mesh/TestFolding.cpp (+52)
- (modified) mlir/tools/mlir-opt/CMakeLists.txt (+1-1)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f70bdaa9de0a0f..f7096cfce634ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -19,6 +19,9 @@
#include <utility>
namespace mlir {
+
+class SymbolTableCollection;
+
namespace mesh {
// If we have an algebraic op like "+" and a summing all-reduce,
@@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
}
void populateSimplificationPatterns(RewritePatternSet &patterns);
+// Populates `patterns` with mesh folding rewrites. Currently this folds the
+// results of `mesh.cluster_shape` that query statically sized mesh axes into
+// constants.
+// The patterns keep a reference to `symbolTable`, so it must outlive every
+// application of `patterns`.
+// It is invalid to change ops that declare symbols during the application of
+// these patterns, because symbolTable is used to cache them.
+void populateFoldingPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTable);
} // namespace mesh
} // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..eab3bc88fd1d38 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -8,6 +8,17 @@
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <numeric>
+#include <utility>
namespace mlir {
namespace mesh {
@@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
// TODO: add simplifications for all-gather and other collectives.
}
+namespace {
+
+// Folds the results of `mesh.cluster_shape` that query statically sized mesh
+// axes into `arith.constant` index values. Dynamic axes are still queried, via
+// a new `mesh.cluster_shape` restricted to just those axes.
+//
+// This folding cannot be done with an operation's fold method, because it
+// needs a SymbolTableCollection to cache the symbol tables. It cannot be done
+// with DialectFoldInterface either, since that cache may be invalidated by
+// passes that change the referenced ClusterOp ops.
+struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+  // `opRewritePatternArgs` are forwarded to the OpRewritePattern constructor
+  // (e.g. the MLIRContext). `symbolTable` is held by reference and must
+  // outlive the pattern.
+  template <typename... OpRewritePatternArgs>
+  ClusterShapeFolder(SymbolTableCollection &symbolTable,
+                     OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern(
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTable(symbolTable) {}
+
+  LogicalResult matchAndRewrite(ClusterShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<ClusterOp>(
+        op.getOperation(), op.getMeshAttr());
+    if (!mesh) {
+      return failure();
+    }
+
+    // Size of mesh axis `axis`, or ShapedType::kDynamic when it is not
+    // statically known. `dim_sizes` is a separate attribute from `rank`, so
+    // guard against it listing fewer entries than there are mesh axes
+    // instead of indexing past its end.
+    ArrayRef<int64_t> meshShape = mesh.getDimSizes();
+    auto axisSize = [meshShape](MeshAxis axis) -> int64_t {
+      size_t index = static_cast<size_t>(axis);
+      return index < meshShape.size() ? meshShape[index]
+                                      : ShapedType::kDynamic;
+    };
+
+    // An empty axes list means "query every mesh axis".
+    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+    SmallVector<MeshAxis> opAxesIota;
+    if (opMeshAxes.empty()) {
+      opAxesIota.resize(mesh.getRank());
+      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
+      opMeshAxes = opAxesIota;
+    }
+    if (llvm::all_of(opMeshAxes, [&](MeshAxis axis) {
+          return ShapedType::isDynamic(axisSize(axis));
+        })) {
+      // All queried mesh dimensions are dynamic. Nothing to fold.
+      return failure();
+    }
+
+    SmallVector<Value> newResults(op->getResults().size());
+    SmallVector<MeshAxis> newShapeOpMeshAxes;
+    SmallVector<size_t> newToOldResultsIndexMap;
+
+    for (size_t i = 0; i < opMeshAxes.size(); ++i) {
+      int64_t meshAxisSize = axisSize(opMeshAxes[i]);
+      if (ShapedType::isDynamic(meshAxisSize)) {
+        newToOldResultsIndexMap.push_back(i);
+        newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+      } else {
+        // Fold static mesh axes.
+        newResults[i] = builder.create<arith::ConstantOp>(
+            builder.getIndexAttr(meshAxisSize));
+      }
+    }
+
+    // Leave only the dynamic mesh axes to be queried. If every axis folded,
+    // no new op is created: a `mesh.cluster_shape` with an empty axes list
+    // would mean "query every mesh axis" (see above), not "query nothing".
+    if (!newShapeOpMeshAxes.empty()) {
+      ClusterShapeOp newShapeOp =
+          builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+      for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
+        newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+      }
+    }
+
+    // replaceOp also erases the original op, so it is not left behind for
+    // the driver to revisit.
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+
+private:
+  SymbolTableCollection &symbolTable;
+};
+
+} // namespace
+
+// Registers ClusterShapeFolder. The pattern keeps a reference to
+// `symbolTable`, so the collection must outlive any application of
+// `patterns`.
+void populateFoldingPatterns(RewritePatternSet &patterns,
+                             SymbolTableCollection &symbolTable) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ClusterShapeFolder>(symbolTable, context);
+}
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
new file mode 100644
index 00000000000000..1283353709ca3c
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
+mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+
+// Mixed static/dynamic axes: axis 2 of @mesh0 has static size 2 and folds to a
+// constant, while dynamic axis 1 is still queried by a narrowed
+// mesh.cluster_shape op.
+// CHECK-LABEL: func.func @cluster_shape_op_folding
+func.func @cluster_shape_op_folding() -> (index, index) {
+ // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
+ %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+ // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
+
+// Every axis of @mesh1 is static, so all results fold to constants and no
+// query op may remain.
+// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
+func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+  // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+  // CHECK-NOT: mesh.cluster_shape
+  %0:2 = mesh.cluster_shape @mesh1 : index, index
+  // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..3da64694ee2155 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTestSimplifications
+add_mlir_library(MLIRMeshTest
+ TestFolding.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp
diff --git a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
new file mode 100644
index 00000000000000..1cf436edea8e35
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
@@ -0,0 +1,52 @@
+//===- TestFolding.cpp - Test folding -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+using namespace mlir;
+
+namespace {
+
+// Test pass (`-test-mesh-folding`) that greedily applies the patterns from
+// mesh::populateFoldingPatterns to the operation it runs on.
+struct TestMeshFoldingPass
+    : public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The folding patterns materialize `arith.constant` ops, so the arith
+    // dialect must be loaded even if the input IR does not mention it.
+    registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-folding"; }
+  StringRef getDescription() const final { return "Test mesh folding."; }
+};
+} // namespace
+
+// Collects the mesh folding patterns and applies them greedily. The
+// SymbolTableCollection lives for the whole rewrite because the patterns
+// hold a reference to it. Non-convergence is reported as a pass failure.
+void TestMeshFoldingPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  SymbolTableCollection symbolTables;
+  mesh::populateFoldingPatterns(patterns, symbolTables);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    getOperation()->emitError()
+        << "Rewrite pattern application did not converge.";
+    return signalPassFailure();
+  }
+}
+
+namespace mlir::test {
+// Makes the pass available to mlir-opt as `-test-mesh-folding`.
+void registerTestMeshFoldingPass() {
+  PassRegistration<TestMeshFoldingPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b6ada66d321880..a5da9390a0c5b3 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLoopLikeInterfaceTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
- MLIRMeshTestSimplifications
+ MLIRMeshTest
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f7a5b3183b50b1..461163f671ce89 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
+void registerTestMeshFoldingPass();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
void registerTestNextAccessPass();
@@ -237,6 +238,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
+ mlir::test::registerTestMeshFoldingPass();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
``````````
</details>
https://github.com/llvm/llvm-project/pull/77033
More information about the Mlir-commits
mailing list