[Mlir-commits] [mlir] [mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (PR #84415)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 7 17:25:36 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
Make the sharding-propagation and mesh-spmdization passes more general: they now run on any op implementing `FunctionOpInterface` instead of only on `func::FuncOp`.
---
Full diff: https://github.com/llvm/llvm-project/pull/84415.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+2-2)
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+3-2)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+10-8)
- (modified) mlir/test/Dialect/Linalg/mesh-spmdization.mlir (+1-2)
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+1-1)
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+3-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 7fb6631574b410..06ebf151e7d649 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
// ShardingPropagation
//===----------------------------------------------------------------------===//
-def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
let summary = "sharding propagation";
let description = [{
Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}
-def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{
This pass fits in right after a pass that annotates the function with
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9f2647b21cbfc8..29320f1e339f86 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
void runOnOperation() override {
- func::FuncOp funcOp = getOperation();
+ FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
- Region &region = funcOp.getBody();
+ Region &region = funcOp.getFunctionBody();
OpBuilder builder(ctx);
if (!region.hasOneBlock()) {
funcOp.emitOpError() << "only one block is supported!";
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c4d8b0b15e462c..e4868435135ed1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -24,6 +24,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
}
static LogicalResult
-spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection) {
OpBuilder builder(op.getFunctionBody());
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
// Find a return op and change the function results signature to its operands
// signature.
- func::ReturnOp returnOp;
- for (Block &block : op.getBody()) {
+ Operation *returnOp = nullptr;
+ for (Block &block : op.getFunctionBody()) {
if (block.empty()) {
continue;
}
- returnOp = llvm::cast<func::ReturnOp>(block.back());
- if (returnOp) {
+ if (block.back().hasTrait<OpTrait::ReturnLike>()) {
+ returnOp = &block.back();
break;
}
}
assert(returnOp);
- op.setFunctionType(FunctionType::get(op->getContext(),
- op.getBody().front().getArgumentTypes(),
- returnOp->getOperandTypes()));
+ op.setType(FunctionType::get(op->getContext(),
+ op.getFunctionBody().front().getArgumentTypes(),
+ returnOp->getOperandTypes()));
return success();
}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 6d21def8de2753..bd56c801283b17 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -1,6 +1,5 @@
// RUN: mlir-opt \
-// RUN: --mesh-spmdization \
-// RUN: --test-constant-fold \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: --split-input-file \
// RUN: %s | FileCheck %s
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 94f8d94073c5ef..270787ab518831 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
mesh.mesh @mesh_1d(shape = ?)
mesh.mesh @mesh_2d(shape = 2x4)
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 572d3eb55eaaae..2df247aba35155 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN: %s | FileCheck %s
mesh.mesh @mesh_1d(shape = 2)
``````````
</details>
https://github.com/llvm/llvm-project/pull/84415
More information about the Mlir-commits mailing list