[Mlir-commits] [mlir] d12019d - [mlir][shard, mpi] lowering shard.all_slice in shard-to-mpi (#176438)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 19 06:37:39 PST 2026


Author: Frank Schlimbach
Date: 2026-01-19T15:37:35+01:00
New Revision: d12019d5e52fd36651f29d5485875e847990172d

URL: https://github.com/llvm/llvm-project/commit/d12019d5e52fd36651f29d5485875e847990172d
DIFF: https://github.com/llvm/llvm-project/commit/d12019d5e52fd36651f29d5485875e847990172d.diff

LOG: [mlir][shard,mpi] lowering shard.all_slice in shard-to-mpi (#176438)

Lower shard.all_slice in shard-to-mpi and reuse the existing lowering for
shard.process_multi_index.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
    mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
    mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
    mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..1096338534416 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1001,10 +1001,12 @@ def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
   let dependentDialects = [
     "affine::AffineDialect",
     "arith::ArithDialect",
+    "bufferization::BufferizationDialect",
+    "cf::ControlFlowDialect",
     "memref::MemRefDialect",
     "mpi::MPIDialect",
     "scf::SCFDialect",
-    "bufferization::BufferizationDialect"
+    "tensor::TensorDialect"
   ];
 }
 

diff  --git a/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
index 564f36fd20abb..dfa8e77bb13ef 100644
--- a/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
@@ -15,6 +15,11 @@ add_mlir_conversion_library(MLIRShardToMPI
   MLIRFuncDialect
   MLIRIR
   MLIRLinalgTransforms
+  MLIRAffineDialect
+  MLIRArithDialect
+  MLIRControlFlowDialect
+  MLIRSCFDialect
+  MLIRTensorDialect
   MLIRMemRefDialect
   MLIRPass
   MLIRShardDialect

diff  --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 398ab88d8199f..b0831dc05abb1 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -11,10 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -71,9 +73,9 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
 }
 
 /// Create operations converting a linear index to a multi-dimensional index.
-static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
-                                             Value linearIndex,
-                                             ValueRange dimensions) {
+[[maybe_unused]] static SmallVector<Value>
+linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex,
+                   ValueRange dimensions) {
   int n = dimensions.size();
   SmallVector<Value> multiIndex(n);
 
@@ -250,46 +252,6 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
   }
 };
 
-struct ConvertProcessMultiIndexOp
-    : public OpConversionPattern<ProcessMultiIndexOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    // Currently converts its linear index to a multi-dimensional index.
-
-    SymbolTableCollection symbolTableCollection;
-    Location loc = op.getLoc();
-    auto gridOp = getGrid(op, symbolTableCollection);
-    // For now we only support static grid shapes
-    if (ShapedType::isDynamicShape(gridOp.getShape()))
-      return failure();
-
-    SmallVector<Value> dims;
-    llvm::transform(
-        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
-          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
-        });
-    Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
-    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
-
-    // optionally extract subset of grid axes
-    auto axes = adaptor.getAxes();
-    if (!axes.empty()) {
-      SmallVector<Value> subIndex;
-      for (auto axis : axes) {
-        subIndex.emplace_back(mIdx[axis]);
-      }
-      mIdx = std::move(subIndex);
-    }
-
-    rewriter.replaceOp(op, mIdx);
-    return success();
-  }
-};
-
 class ConvertProcessLinearIndexOp
     : public OpConversionPattern<ProcessLinearIndexOp> {
 
@@ -919,10 +881,11 @@ struct ConvertShardToMPIPass
     // ...except the global GridOp. GridShapeOp which will get folded later.
     target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
     // Allow all the stuff that our patterns will convert to
-    target.addLegalDialect<
-        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
-        tensor::TensorDialect, bufferization::BufferizationDialect,
-        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
+    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
+                           arith::ArithDialect, tensor::TensorDialect,
+                           bufferization::BufferizationDialect,
+                           linalg::LinalgDialect, memref::MemRefDialect,
+                           affine::AffineDialect, cf::ControlFlowDialect>();
     // Make sure the function signature, calls etc. are legal
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return typeConverter.isSignatureLegal(op.getFunctionType());
@@ -931,9 +894,12 @@ struct ConvertShardToMPIPass
         [&](Operation *op) { return typeConverter.isLegal(op); });
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
-                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
-                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
-                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
+                 ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
+                 ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter,
+                                                                  ctxt);
+    SymbolTableCollection stc;
+    populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
+    populateAllSliceOpLoweringPatterns(patterns, stc);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
         patterns, typeConverter);

diff  --git a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index 772e66fee5c56..b433b8b0be7b2 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -66,7 +66,7 @@ struct ProcessMultiIndexOpLowering
                     [&completeMultiIndex](GridAxis gridAxis) {
                       return completeMultiIndex[gridAxis];
                     });
-    rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
+    rewriter.replaceOp(op, multiIndex);
     return success();
   }
 };
@@ -157,8 +157,7 @@ struct AllSliceOpLowering
                                                  offsets, sizes, strides);
     Value newResult =
         tensor::CastOp::create(builder, op.getResult().getType(), slice);
-    rewriter.replaceAllUsesWith(op.getResult(), newResult);
-
+    rewriter.replaceOp(op, newResult);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 5e20b5a59d927..a0b6bfaf6fd3d 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -5,11 +5,9 @@
 shard.grid @grid0(shape = 3x4x5)
 func.func @process_multi_index() -> (index, index, index) {
   // CHECK: mpi.comm_rank
-  // CHECK-DAG: %[[v4:.*]] = arith.remsi
-  // CHECK-DAG: %[[v0:.*]] = arith.remsi
-  // CHECK-DAG: %[[v1:.*]] = arith.remsi
+  // CHECK: [[res:%.*]]:3 = affine.delinearize_index %1 into (3, 4, 5) : index, index, index 
   %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
-  // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
+  // CHECK: return [[res]]#0, [[res]]#1, [[res]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
@@ -80,6 +78,23 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   }
 }
 
+// -----
+// CHECK: shard.grid @grid0
+module {
+  shard.grid @grid0(shape = 3x4x5)
+  // CHECK-LABEL: func @all_slice
+  func.func @all_slice(%arg0 : tensor<3x5xf32>) -> tensor<3x1xf32> {
+    // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[vretval:%.*]], [[vrank:%.*]] = mpi.comm_rank([[v0]]) : !mpi.retval, i32
+    // CHECK: [[v1:%.*]] = arith.index_cast [[vrank]] : i32 to index
+    // CHECK: [[v2:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
+    // CHECK: [[vextracted_slice:%.*]] = tensor.extract_slice
+    // CHECK-SAME: [0, [[v2]]#2] [3, 1] [1, 1] : tensor<3x5xf32> to tensor<3x1xf32>
+    %1 = shard.all_slice %arg0 on @grid0 grid_axes = [2] slice_axis = 1 : tensor<3x5xf32> -> tensor<3x1xf32>
+    return %1 : tensor<3x1xf32>
+  }
+}
+
 // -----
 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   shard.grid @grid0(shape = 3x4x5)


        


More information about the Mlir-commits mailing list