[Mlir-commits] [mlir] abfac56 - [mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (#84415)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 8 08:14:39 PST 2024


Author: Boian Petkantchin
Date: 2024-03-08T08:14:36-08:00
New Revision: abfac563f5b5a123e4bf773c3a09777e6fc4f50c

URL: https://github.com/llvm/llvm-project/commit/abfac563f5b5a123e4bf773c3a09777e6fc4f50c
DIFF: https://github.com/llvm/llvm-project/commit/abfac563f5b5a123e4bf773c3a09777e6fc4f50c.diff

LOG: [mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (#84415)

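Make the sharding-propagation and mesh-spmdization passes more general: they now run on any operation implementing `FunctionOpInterface` instead of only on `func::FuncOp`. Because they are now interface passes, the tests invoke them through an explicit `--pass-pipeline` that nests them under `func.func`.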

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
    mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
    mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
    mlir/test/Dialect/Linalg/mesh-spmdization.mlir
    mlir/test/Dialect/Mesh/sharding-propagation.mlir
    mlir/test/Dialect/Mesh/spmdization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 7fb6631574b410..06ebf151e7d649 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
 // ShardingPropagation
 //===----------------------------------------------------------------------===//
 
-def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
   let summary = "sharding propagation";
   let description = [{
     Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }
 
-def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
   let summary = "Partition a function into SPMD form.";
   let description = [{
     This pass fits in right after a pass that annotates the function with

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9f2647b21cbfc8..29320f1e339f86 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
 #include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
+    FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
-    Region &region = funcOp.getBody();
+    Region &region = funcOp.getFunctionBody();
     OpBuilder builder(ctx);
     if (!region.hasOneBlock()) {
       funcOp.emitOpError() << "only one block is supported!";

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c4d8b0b15e462c..e4868435135ed1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -24,6 +24,8 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
 }
 
 static LogicalResult
-spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
               SymbolTableCollection &symbolTableCollection) {
   OpBuilder builder(op.getFunctionBody());
 
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
 
   // Find a return op and change the function results signature to its operands
   // signature.
-  func::ReturnOp returnOp;
-  for (Block &block : op.getBody()) {
+  Operation *returnOp = nullptr;
+  for (Block &block : op.getFunctionBody()) {
     if (block.empty()) {
       continue;
     }
 
-    returnOp = llvm::cast<func::ReturnOp>(block.back());
-    if (returnOp) {
+    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
+      returnOp = &block.back();
       break;
     }
   }
   assert(returnOp);
-  op.setFunctionType(FunctionType::get(op->getContext(),
-                                       op.getBody().front().getArgumentTypes(),
-                                       returnOp->getOperandTypes()));
+  op.setType(FunctionType::get(op->getContext(),
+                               op.getFunctionBody().front().getArgumentTypes(),
+                               returnOp->getOperandTypes()));
 
   return success();
 }

diff  --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 6d21def8de2753..bd56c801283b17 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt \
-// RUN:  --mesh-spmdization \
-// RUN:  --test-constant-fold \
+// RUN:  --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
 // RUN:  --split-input-file \
 // RUN:  %s | FileCheck %s
 

diff  --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 94f8d94073c5ef..270787ab518831 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = ?)
 mesh.mesh @mesh_2d(shape = 2x4)

diff  --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 572d3eb55eaaae..2df247aba35155 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN:   %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = 2)
 


        


More information about the Mlir-commits mailing list