[Mlir-commits] [mlir] d12019d - [mlir][shard, mpi] lowering shard.all_slice in shard-to-mpi (#176438)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 19 06:37:39 PST 2026
Author: Frank Schlimbach
Date: 2026-01-19T15:37:35+01:00
New Revision: d12019d5e52fd36651f29d5485875e847990172d
URL: https://github.com/llvm/llvm-project/commit/d12019d5e52fd36651f29d5485875e847990172d
DIFF: https://github.com/llvm/llvm-project/commit/d12019d5e52fd36651f29d5485875e847990172d.diff
LOG: [mlir][shard,mpi] lowering shard.all_slice in shard-to-mpi (#176438)
Lower shard.all_slice in shard-to-mpi, reusing the generic Shard transform
lowering patterns for it and for shard.process_multi_index.
Added:
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..1096338534416 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1001,10 +1001,12 @@ def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
let dependentDialects = [
"affine::AffineDialect",
"arith::ArithDialect",
+ "bufferization::BufferizationDialect",
+ "cf::ControlFlowDialect",
"memref::MemRefDialect",
"mpi::MPIDialect",
"scf::SCFDialect",
- "bufferization::BufferizationDialect"
+ "tensor::TensorDialect"
];
}
diff --git a/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
index 564f36fd20abb..dfa8e77bb13ef 100644
--- a/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
@@ -15,6 +15,11 @@ add_mlir_conversion_library(MLIRShardToMPI
MLIRFuncDialect
MLIRIR
MLIRLinalgTransforms
+ MLIRAffineDialect
+ MLIRArithDialect
+ MLIRControlFlowDialect
+ MLIRSCFDialect
+ MLIRTensorDialect
MLIRMemRefDialect
MLIRPass
MLIRShardDialect
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 398ab88d8199f..b0831dc05abb1 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -11,10 +11,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -71,9 +73,9 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
}
/// Create operations converting a linear index to a multi-dimensional index.
-static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
- Value linearIndex,
- ValueRange dimensions) {
+[[maybe_unused]] static SmallVector<Value>
+linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex,
+ ValueRange dimensions) {
int n = dimensions.size();
SmallVector<Value> multiIndex(n);
@@ -250,46 +252,6 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
}
};
-struct ConvertProcessMultiIndexOp
- : public OpConversionPattern<ProcessMultiIndexOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
-
- // Currently converts its linear index to a multi-dimensional index.
-
- SymbolTableCollection symbolTableCollection;
- Location loc = op.getLoc();
- auto gridOp = getGrid(op, symbolTableCollection);
- // For now we only support static grid shapes
- if (ShapedType::isDynamicShape(gridOp.getShape()))
- return failure();
-
- SmallVector<Value> dims;
- llvm::transform(
- gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
- return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
- });
- Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
- auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
-
- // optionally extract subset of grid axes
- auto axes = adaptor.getAxes();
- if (!axes.empty()) {
- SmallVector<Value> subIndex;
- for (auto axis : axes) {
- subIndex.emplace_back(mIdx[axis]);
- }
- mIdx = std::move(subIndex);
- }
-
- rewriter.replaceOp(op, mIdx);
- return success();
- }
-};
-
class ConvertProcessLinearIndexOp
: public OpConversionPattern<ProcessLinearIndexOp> {
@@ -919,10 +881,11 @@ struct ConvertShardToMPIPass
// ...except the global GridOp. GridShapeOp which will get folded later.
target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
// Allow all the stuff that our patterns will convert to
- target.addLegalDialect<
- BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
- tensor::TensorDialect, bufferization::BufferizationDialect,
- linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
+ target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
+ arith::ArithDialect, tensor::TensorDialect,
+ bufferization::BufferizationDialect,
+ linalg::LinalgDialect, memref::MemRefDialect,
+ affine::AffineDialect, cf::ControlFlowDialect>();
// Make sure the function signature, calls etc. are legal
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType());
@@ -931,9 +894,12 @@ struct ConvertShardToMPIPass
[&](Operation *op) { return typeConverter.isLegal(op); });
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
- ConvertProcessMultiIndexOp, ConvertGetShardingOp,
- ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
- ConvertProcessLinearIndexOp>(typeConverter, ctxt);
+ ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
+ ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter,
+ ctxt);
+ SymbolTableCollection stc;
+ populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
+ populateAllSliceOpLoweringPatterns(patterns, stc);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
diff --git a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index 772e66fee5c56..b433b8b0be7b2 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -66,7 +66,7 @@ struct ProcessMultiIndexOpLowering
[&completeMultiIndex](GridAxis gridAxis) {
return completeMultiIndex[gridAxis];
});
- rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
+ rewriter.replaceOp(op, multiIndex);
return success();
}
};
@@ -157,8 +157,7 @@ struct AllSliceOpLowering
offsets, sizes, strides);
Value newResult =
tensor::CastOp::create(builder, op.getResult().getType(), slice);
- rewriter.replaceAllUsesWith(op.getResult(), newResult);
-
+ rewriter.replaceOp(op, newResult);
return success();
}
};
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 5e20b5a59d927..a0b6bfaf6fd3d 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -5,11 +5,9 @@
shard.grid @grid0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK: mpi.comm_rank
- // CHECK-DAG: %[[v4:.*]] = arith.remsi
- // CHECK-DAG: %[[v0:.*]] = arith.remsi
- // CHECK-DAG: %[[v1:.*]] = arith.remsi
+ // CHECK: [[res:%.*]]:3 = affine.delinearize_index %1 into (3, 4, 5) : index, index, index
%0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
- // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
+ // CHECK: return [[res]]#0, [[res]]#1, [[res]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
@@ -80,6 +78,23 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
}
}
+// -----
+// CHECK: shard.grid @grid0
+module {
+ shard.grid @grid0(shape = 3x4x5)
+ // CHECK-LABEL: func @all_slice
+ func.func @all_slice(%arg0 : tensor<3x5xf32>) -> tensor<3x1xf32> {
+ // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: [[vretval:%.*]], [[vrank:%.*]] = mpi.comm_rank([[v0]]) : !mpi.retval, i32
+ // CHECK: [[v1:%.*]] = arith.index_cast [[vrank]] : i32 to index
+ // CHECK: [[v2:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
+ // CHECK: [[vextracted_slice:%.*]] = tensor.extract_slice
+ // CHECK-SAME: [0, [[v2]]#2] [3, 1] [1, 1] : tensor<3x5xf32> to tensor<3x1xf32>
+ %1 = shard.all_slice %arg0 on @grid0 grid_axes = [2] slice_axis = 1 : tensor<3x5xf32> -> tensor<3x1xf32>
+ return %1 : tensor<3x1xf32>
+ }
+}
+
// -----
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
shard.grid @grid0(shape = 3x4x5)
More information about the Mlir-commits
mailing list