[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)
Frank Schlimbach
llvmlistbot at llvm.org
Mon Feb 10 04:13:07 PST 2025
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/124724
>From 75e1ab9dc9959d9b7709f184d1bfc9b0297044cb Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 27 Nov 2024 16:38:24 +0100
Subject: [PATCH 1/6] Allowing constant-like operands to ShardingInterface ops
Attaching ShardingInterface to arith::ConstantOp
---
.../Arith/Transforms/ShardingInterfaceImpl.h | 23 +++++
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 2 +-
mlir/include/mlir/InitAllDialects.h | 2 +
.../Dialect/Arith/Transforms/CMakeLists.txt | 1 +
.../Transforms/ShardingInterfaceImpl.cpp | 99 +++++++++++++++++++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 19 +++-
.../Mesh/Interfaces/ShardingInterface.cpp | 17 ++--
.../Mesh/Transforms/ShardingPropagation.cpp | 3 +-
.../Dialect/Mesh/Transforms/Spmdization.cpp | 35 ++++---
.../Extensions/MeshShardingExtensions.cpp | 15 +--
mlir/test/Dialect/Arith/mesh-spmdize.cpp | 17 ++++
.../Dialect/Arith/sharding-propagation.mlir | 54 ++++++++++
12 files changed, 251 insertions(+), 36 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
create mode 100644 mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.cpp
create mode 100644 mlir/test/Dialect/Arith/sharding-propagation.mlir
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..5addffbe571bee1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 75cb096130ca6e4..210b82151ede4e8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -62,7 +62,7 @@ class MeshSharding {
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
- ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+ ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0da82825c82878a..33bc89279c08c32 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ arith::registerShardingInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6149b35befe7de2..30dd84aff120f36 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
+ ShardingInterfaceImpl.cpp
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..fc033294eb01b28
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.empty/arith.splat
+struct ConstantShardingInterface
+ : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+ ConstantOp> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ auto ndims = 0;
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ ndims = type.getRank();
+ }
+ return SmallVector<utils::IteratorType>(ndims,
+ utils::IteratorType::parallel);
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
+ type.getRank(), op->getContext())});
+ }
+ return {};
+ }
+
+ // Indicate failure if no result sharding exists.
+ // Otherwise, mirror the result sharding if it is a tensor constant.
+ // Otherwise, return the replication option.
+ FailureOr<ShardingOption>
+ getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings) const {
+ if (!resultShardings[0]) {
+ return failure();
+ }
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
+ for (auto [i, axes] :
+ llvm::enumerate(resultShardings[0].getSplitAxes())) {
+ axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
+ }
+ return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+ }
+ return ShardingOption({}, resultShardings[0].getMeshAttr());
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ auto cOp = cast<ConstantOp>(op);
+ auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
+ if (value) {
+ if (!value.isSplat() || !resultShardings[0]) {
+ // Currently non-splat constants are not supported.
+ return failure();
+ }
+ auto sharding = resultShardings[0];
+ auto newType = cast<RankedTensorType>(shardType(
+ cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+ sharding));
+ auto newValue = value.resizeSplat(newType);
+ auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
+ spmdizationMap.map(op->getResult(0), newOp.getResult());
+ spmdizationMap.map(op, newOp.getOperation());
+ } else {
+ // `clone` will populate the mapping of old to new results.
+ (void)builder.clone(*op, spmdizationMap);
+ }
+ return success();
+ }
+};
+} // namespace
+
+void mlir::arith::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+ registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+ ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 33460ff25e9e45d..352bf476e3f570a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -316,9 +316,13 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
- Operation *operandOp = operand.getOwner();
Operation *operandSrcOp = operandValue.getDefiningOp();
bool isBlockArg = !operandSrcOp;
+ if(!isBlockArg && operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+ return;
+ }
+
+ Operation *operandOp = operand.getOwner();
ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
if (shardOp && sharding == shardOp.getSharding() &&
@@ -710,8 +714,13 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
MeshSharding::MeshSharding(Value rhs) {
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
assert(shardingOp && "expected sharding op");
- *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
- shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
+ auto splitAxes = shardingOp.getSplitAxes().getAxes();
+ auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
+ if(splitAxes.empty() && partialAxes.empty()) {
+ *this = MeshSharding();
+ return;
+ }
+ *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
shardingOp.getPartialType().value_or(ReductionKind::Sum),
shardingOp.getStaticHaloSizes(),
shardingOp.getStaticShardedDimsOffsets(),
@@ -727,6 +736,10 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<int64_t> static_sharded_dims_offsets_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+ if(split_axes_.empty() && partial_axes_.empty()) {
+ return MeshSharding();
+ }
+
MeshSharding res;
res.mesh = mesh_;
res.split_axes.resize(split_axes_.size());
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index c1f4d563d5b42c3..aae2d4ccfeed916 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -168,16 +168,16 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
// check operands and results type
for (Type type : op->getOperandTypes())
- if (!llvm::isa<RankedTensorType>(type))
+ if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();
for (Type type : op->getResultTypes())
- if (!llvm::isa<RankedTensorType>(type))
+ if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();
// check loop types
- SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
- if (loopTypes.empty())
- return failure();
+ // SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
+ // if (loopTypes.empty())
+ // return failure();
// check maps
SmallVector<AffineMap> maps = getIndexingMaps();
@@ -448,7 +448,12 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
Value operandValue = opOperand.get();
- auto operandType = cast<RankedTensorType>(operandValue.getType());
+ auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
+ if(!operandType) {
+ if(operandValue.getType().isIntOrIndexOrFloat())
+ return MeshSharding();
+ return failure();
+ }
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4bd3b425219c1ae..f96d54424a2fe8d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -282,11 +282,12 @@ static FailureOr<ShardingOption> selectShardingOption(
// a `mesh.shard` operation for all remaining operands and results that do not
// have sharding annotations.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+ ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (op->hasTrait<OpTrait::IsTerminator>() ||
+ (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
return success();
- ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (!shardingOp) {
op->emitOpError() << "sharding interface is not implemented.";
return failure();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 327ea0991e4e1ea..04932f11e6b433f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -636,14 +636,6 @@ shardedBlockArgumentTypes(Block &block,
return res;
}
-void spmdizeTriviallyShardableOperation(Operation &op,
- ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder);
-
static LogicalResult spmdizeOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
@@ -697,14 +689,15 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
std::vector<MeshSharding> res;
res.reserve(op.getNumResults());
llvm::transform(op.getResults(), std::back_inserter(res),
- [](OpResult result) {
+ [&op](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
- if (!rankedTensor) {
+ if (!rankedTensor || op.hasTrait<OpTrait::ConstantLike>()) {
+ return MeshSharding();
+ }
+ if (!result.hasOneUse()) {
return MeshSharding();
}
-
- assert(result.hasOneUse());
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
return MeshSharding(shardOp.getSharding());
@@ -765,6 +758,7 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
+
SmallVector<Location> argLocations;
llvm::transform(block.getArguments(), std::back_inserter(argLocations),
[](BlockArgument arg) { return arg.getLoc(); });
@@ -796,8 +790,12 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
// Snapshot the original blocks to not mess up the iteration when adding new
// blocks.
SmallVector<Block *> originalBlocks;
- llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
- [](Block &b) { return &b; });
+ for (Block &b : op.getBlocks()) {
+ if (llvm::any_of(b.getOperations(),
+ [](Operation &op) { return isa<ShardOp>(op); })) {
+ originalBlocks.push_back(&b);
+ }
+ }
for (Block *block : originalBlocks) {
if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
@@ -823,10 +821,11 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
break;
}
}
- assert(returnOp);
- op.setType(FunctionType::get(op->getContext(),
- op.getFunctionBody().front().getArgumentTypes(),
- returnOp->getOperandTypes()));
+ if (returnOp) {
+ op.setType(FunctionType::get(
+ op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
+ returnOp->getOperandTypes()));
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index f3e72abe7516eeb..6bb5d4a66f39eaf 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -22,10 +22,10 @@ using namespace mlir::mesh;
namespace {
-// Sharding of tensor.empty
-struct EmptyOpShardingInterface
- : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
- tensor::EmptyOp> {
+// Sharding of tensor.empty/tensor.splat
+template<typename OpTy>
+struct CreatorOpShardingInterface
+ : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +38,7 @@ struct EmptyOpShardingInterface
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
- return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
+ return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -83,7 +83,7 @@ struct EmptyOpShardingInterface
}
}
newOp =
- builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
+ builder.create<OpTy>(op->getLoc(), shardType, newOperands);
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
@@ -100,6 +100,7 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
+ EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
+ SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
new file mode 100644
index 000000000000000..0688e14b1cf7212
--- /dev/null
+++ b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
@@ -0,0 +1,17 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
+// RUN: %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func @test_spmdize_constant
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
+// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
+// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+ %ci = arith.constant 434 : i32
+ return %sharding_annotated_1 : tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir
new file mode 100644
index 000000000000000..19eb340549b0beb
--- /dev/null
+++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
+// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
+func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ %ci = arith.constant 43.4e+00 : f32
+ %o1 = tensor.empty() : tensor<1024x1024xf32>
+ %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+ return %res : tensor<1024x1024xf32>
+}
+
+// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
+// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %ci = arith.constant 43.4e+00 : f32
+ %o1 = tensor.empty() : tensor<1024x1024xf32>
+ %res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+ %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32>
+ return %sharding_annotated_1 : tensor<1024x1024xf32>
+}
>From ef6671b6c5638ea3dd12f29877c9512341e3456d Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 4 Dec 2024 11:04:59 +0100
Subject: [PATCH 2/6] better handling of replicated tensors
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 2 +-
.../Mesh/Interfaces/ShardingInterface.h | 4 ++-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 22 +++++++++---
.../Mesh/Interfaces/ShardingInterface.cpp | 36 +++++++++++--------
.../Dialect/Mesh/Transforms/Spmdization.cpp | 3 +-
5 files changed, 45 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 210b82151ede4e8..626f2fcf93b368a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,7 +51,7 @@ class MeshSharding {
SmallVector<Value> dynamic_sharded_dims_offsets;
public:
- MeshSharding() = default;
+ MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index b4d25cef05a7b96..14aad7f9f6783d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -36,7 +36,9 @@ struct ShardingOption {
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
- : shardingArray(std::move(shardingArray)), mesh(mesh) {}
+ : shardingArray(std::move(shardingArray)), mesh(mesh) {
+ assert(this->mesh);
+ }
static ShardingOption makeEmpty() {
auto res = ShardingOption();
res.empty = true;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 352bf476e3f570a..5e342a855d6aef6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
+ // 0d tensors cannot be sharded and must get replicated
+ if (inShape.empty()) {
+ assert(outShape.empty());
+ return;
+ }
+
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
@@ -318,7 +324,12 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
Value operandValue = operand.get();
Operation *operandSrcOp = operandValue.getDefiningOp();
bool isBlockArg = !operandSrcOp;
- if(!isBlockArg && operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+ {
+ auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
+ assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
+ }
+ if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
+ operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
return;
}
@@ -711,13 +722,15 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
return !(*this == rhs);
}
+MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
+
MeshSharding::MeshSharding(Value rhs) {
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
if(splitAxes.empty() && partialAxes.empty()) {
- *this = MeshSharding();
+ *this = MeshSharding(shardingOp.getMeshAttr());
return;
}
*this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
@@ -736,12 +749,11 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<int64_t> static_sharded_dims_offsets_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+ MeshSharding res(mesh_);
if(split_axes_.empty() && partial_axes_.empty()) {
- return MeshSharding();
+ return res;
}
- MeshSharding res;
- res.mesh = mesh_;
res.split_axes.resize(split_axes_.size());
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
res.split_axes[i] =
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index aae2d4ccfeed916..aaffe759b0cef08 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -286,18 +286,22 @@ mesh::detail::defaultGetShardingOption(Operation *op,
continue;
AffineMap map = maps[numOperands + shardingIt.index()];
anyShardingInResultsOrOperands = true;
- // Handle the split axes: calculate the corresponding loop index for each
- // split axes sub-array, and then store the sub-array to
- // shardingOption[index]
- for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
- AffineExpr expr = std::get<0>(it);
- ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
- auto dim = cast<AffineDimExpr>(expr);
- unsigned index = dim.getPosition();
- visitedLoopIndices.insert(index);
- if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
- axes, index)))
- return failure();
+ if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
+ shardingOption.mesh = shardAttr.getMeshAttr();
+ } else {
+ // Handle the split axes: calculate the corresponding loop index for each
+ // split axes sub-array, and then store the sub-array to
+ // shardingOption[index]
+ for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
+ AffineExpr expr = std::get<0>(it);
+ ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+ auto dim = cast<AffineDimExpr>(expr);
+ unsigned index = dim.getPosition();
+ visitedLoopIndices.insert(index);
+ if (failed(fillShardingOption(op, shardingOption,
+ shardAttr.getMeshAttr(), axes, index)))
+ return failure();
+ }
}
// Handle the partial axes: at this stage, the exact loop index/indices
@@ -323,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
if (!shardAttr)
continue;
- anyShardingInResultsOrOperands = true;
+ anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
AffineMap map = maps[shardingIt.index()];
unsigned numDims = map.getNumDims();
@@ -454,6 +458,10 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
return MeshSharding();
return failure();
}
+ // 0d tensors cannot be sharded and must get replicated
+ if (operandType.getRank() == 0) {
+ return MeshSharding(shardingOption.mesh);
+ }
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
@@ -584,7 +592,7 @@ static bool
isValueCompatibleWithFullReplicationSharding(Value value,
MeshSharding sharding) {
if (isa<RankedTensorType>(value.getType())) {
- return sharding && isFullReplication(sharding);
+ return isFullReplication(sharding);
}
return !sharding;
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 04932f11e6b433f..27297a8be5d069f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -561,7 +561,8 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
- if (sourceSharding == targetSharding) {
+ if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
+ isFullReplication(targetSharding))) {
return sourceShard;
}
>From 6a633617bdfc654817b7d2f1d5cb13d84bb84d1a Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 5 Dec 2024 12:55:23 +0100
Subject: [PATCH 3/6] canonicalize ShardOp and ShardingOp
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 3 +-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 87 +++++++++++++++++--
.../Dialect/Mesh/Transforms/Spmdization.cpp | 2 +-
mlir/test/Dialect/Mesh/canonicalization.mlir | 40 ++++++++-
4 files changed, 122 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 6039e61a93fadc5..531020930768e6b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let summary = "Description of a device/process mesh.";
let description = [{
The mesh.mesh operation is a symbol operation that identifies a specific
@@ -460,6 +460,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
}];
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 5e342a855d6aef6..6a1498c0f681483 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -594,9 +594,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
namespace {
// Sharding annotations "halo sizes" and "sharded dims offsets"
// are a mix of attributes and dynamic values. This canonicalization moves
-// constant values to the respective attribute lists and so minimizes the number
+// constant values to the respective attribute lists, minimizing the number
// of values.
-class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+// It also removes sharded_dims_offsets and halo_sizes if they are effectively "empty".
+class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
using OpRewritePattern<ShardingOp>::OpRewritePattern;
@@ -608,14 +609,39 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
op.getDynamicShardedDimsOffsets(), b);
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
- failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
- return failure();
- }
+ bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
+ succeeded(foldDynamicIndexList(mixedOffs, true));
auto halos = decomposeMixedValues(mixedHalos);
auto offs = decomposeMixedValues(mixedOffs);
+ if (halos.second.empty() && !halos.first.empty()) {
+ if (halos.first[0] == 0 && llvm::all_equal(halos.first)) {
+ halos.first.clear();
+ modified = true;
+ }
+ }
+
+ if (offs.second.empty() && !offs.first.empty()) {
+ assert(offs.first.size() >= 2);
+ auto diff = offs.first[1] - offs.first[0];
+ bool all_same = offs.first.size() > 2;
+ for (auto i = 2u; i < offs.first.size(); ++i) {
+ if (offs.first[i] - offs.first[i - 1] != diff) {
+ all_same = false;
+ break;
+ }
+ }
+ if (all_same) {
+ offs.first.clear();
+ modified = true;
+ }
+ }
+
+ if (!modified) {
+ return failure();
+ }
+
op.setStaticHaloSizes(halos.first);
op.getDynamicHaloSizesMutable().assign(halos.second);
op.setStaticShardedDimsOffsets(offs.first);
@@ -628,7 +654,7 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
mlir::MLIRContext *context) {
- results.add<FoldDynamicLists>(context);
+ results.add<NormalizeSharding>(context);
}
//===----------------------------------------------------------------------===//
@@ -796,6 +822,53 @@ void ShardOp::getAsmResultNames(
setNameFn(getResult(), "sharding_annotated");
}
+namespace {
+// Determine if the given ShardOp is a duplicate of another ShardOp
+// on the same value. This can happen if constant values are sharded.
+class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
+public:
+ using OpRewritePattern<ShardOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
+ // Get the use-list of the value being sharded and check if it has more than
+ // one use.
+ Value value = op.getSrc();
+ if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
+ return failure();
+ }
+
+ // Iterate through the uses of the value to find a duplicate ShardOp.
+ for (auto &use : value.getUses()) {
+ if (use.getOwner() != op.getOperation()) {
+ auto otherOp = dyn_cast<ShardOp>(use.getOwner());
+ if (!otherOp || !otherOp->isBeforeInBlock(op)) {
+ return failure();
+ }
+ // Create a MeshSharding object for the current and the other ShardOp.
+ // If the two are equal, replace the current op with the other op.
+ MeshSharding currentSharding(op.getSharding());
+ MeshSharding otherSharding(otherOp.getSharding());
+ if (currentSharding == otherSharding) {
+ b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
+ b.eraseOp(op.getOperation());
+ } else {
+ // Use the other sharding as input for this op.
+ op.getSrcMutable().assign(otherOp.getResult());
+ }
+ return success();
+ }
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+ mlir::MLIRContext *context) {
+ results.add<FoldDuplicateShardOp>(context);
+}
+
//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 27297a8be5d069f..e6fe0fd5d1e8789 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -693,7 +693,7 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
[&op](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
- if (!rankedTensor || op.hasTrait<OpTrait::ConstantLike>()) {
+ if (!rankedTensor) {
return MeshSharding();
}
if (!result.hasOneUse()) {
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index f0112d689805d32..aff07bbf8a21413 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -207,4 +207,42 @@ func.func @test_shard_offs() -> !mesh.sharding {
// CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
return %sharding : !mesh.sharding
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops
+func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
+ %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+ %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+ %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+ %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+ %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+ %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+ return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops_diff
+func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+ %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+ %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
+ %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+ // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
+ %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+ %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
+ %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+ return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
>From d52ca9a108e7b3e4d3d3dd42d7482e5162035df5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 19 Dec 2024 13:19:32 +0100
Subject: [PATCH 4/6] sharding propagation: add only one shardop for each
result
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 9 ++++--
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 33 +++++++++++++--------
2 files changed, 26 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 626f2fcf93b368a..7de7842baf98abf 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Insert shard op if there is not one that already has the same sharding.
+// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder);
+// Return the target ShardOp (new or existing).
+ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder,
+ ShardOp newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6a1498c0f681483..2fff67c44a8ac84 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -275,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder) {
+ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder,
+ ShardOp newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
@@ -286,13 +287,16 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
- return;
+ return newShardOp ? newShardOp : shardOp;
}
- auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
- auto newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
- /*annotate_for_users*/ false);
+ if (!newShardOp) {
+ auto shardingOp =
+ builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+ newShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+ /*annotate_for_users*/ false);
+ }
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -300,20 +304,23 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
});
if (!shardOp || shardOp.getAnnotateForUsers()) {
- return;
+ return newShardOp;
}
- auto newShardOp2 =
- builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
- /*annotate_for_users*/ true);
+ auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+ newShardOp.getSharding(),
+ /*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+ return newShardOp;
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
+ ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
- maybeInsertTargetShardingAnnotation(sharding, use, builder);
+ newShardOp =
+ maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}
>From 3c76df3d4552953b0d5fa6719b31d68796fda199 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 13 Jan 2025 15:36:54 +0100
Subject: [PATCH 5/6] Adding sharding extraction operation and op tests and
handling GetShardingOp in ShardingPropagation
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 21 ++++++++++++++++++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 22 ++++++++++++++-----
.../Mesh/Transforms/ShardingPropagation.cpp | 2 +-
.../Dialect/Mesh/Transforms/Spmdization.cpp | 9 ++++++++
mlir/test/Dialect/Mesh/ops.mlir | 10 +++++++++
mlir/test/Dialect/Mesh/spmdization.mlir | 14 ++++++++++++
6 files changed, 71 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 531020930768e6b..031e6f63bcb42cc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
+ OpBuilder<(ins "llvm::StringRef":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
+ )>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
+def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+ let summary = "Get the sharding of the given tensor.";
+ let description = [{
+ This operation returns the sharding of the given tensor as a MeshSharding.
+ }];
+ let arguments = (ins
+ AnyRankedTensor:$source
+ );
+ let results = (outs
+ Mesh_Sharding:$result
+ );
+ let assemblyFormat = [{
+ $source attr-dict `:` type($source) `->` type($result)
+ }];
+}
+
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 2fff67c44a8ac84..f84d46704852228 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -454,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
- ArrayRef<int64_t> static_halo_sizes,
- ArrayRef<int64_t> static_sharded_dims_offsets) {
+ ArrayRef<int64_t> static_halos,
+ ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(),
- static_sharded_dims_offsets),
- {});
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -475,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
{}, {}, {}, {});
}
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+ llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+ ArrayRef<int64_t> static_halos,
+ ArrayRef<int64_t> static_offsets) {
+ return build(
+ b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+ MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+ ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+}
+
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index f96d54424a2fe8d..8c989cce634064a 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -285,7 +285,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (op->hasTrait<OpTrait::IsTerminator>() ||
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
- llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
+ llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
return success();
if (!shardingOp) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e6fe0fd5d1e8789..4ec8bbc0dff7d44 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -738,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
if (isa<ShardingOp>(op)) {
return success();
}
+ if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
+ auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
+ if (!shardOp) {
+ return op.emitError("expected a shard op as source of get_sharding");
+ }
+ auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
+ spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
+ return success();
+ }
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
if (shardOp) {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 978de4939ee77c2..dae21655afb23ec 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -164,6 +164,16 @@ func.func @mesh_shard_shape() {
return
}
+// CHECK-LABEL: func @mesh_get_sharding
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
+ // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
+ %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
+ return %0 : !mesh.sharding
+}
+
// CHECK-LABEL: func @mesh_shape
func.func @mesh_shape() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index c1b96fda0f4a741..59f7162e21013db 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -4,6 +4,20 @@
mesh.mesh @mesh_1d(shape = 2)
+// CHECK-LABEL: func @return_sharding
+func.func @return_sharding(
+ // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
+ %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
+) -> (tensor<2xf32>, !mesh.sharding) {
+ %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
+ %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
+ // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
+ return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
+}
+
// CHECK-LABEL: func @full_replication
func.func @full_replication(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
>From 508095ac58bd85dfd6a0cdc7bad93ef57fbc8610 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 28 Jan 2025 10:39:46 +0100
Subject: [PATCH 6/6] comments
adding libs
clang-format
renaming mesh-spmdize.cpp -> mesh-spmdize.mlir and fixing format
---
.../Dialect/Arith/Transforms/CMakeLists.txt | 2 ++
.../Arith/Transforms/ShardingInterfaceImpl.cpp | 2 +-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 9 ++++++---
.../Mesh/Interfaces/ShardingInterface.cpp | 9 ++-------
.../Dialect/Mesh/Transforms/Spmdization.cpp | 2 +-
.../Extensions/MeshShardingExtensions.cpp | 18 +++++++++++-------
mlir/test/Dialect/Arith/mesh-spmdize.cpp | 17 -----------------
mlir/test/Dialect/Arith/mesh-spmdize.mlir | 17 +++++++++++++++++
8 files changed, 40 insertions(+), 36 deletions(-)
delete mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.cpp
create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 30dd84aff120f36..f96bda603baa63d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -27,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
+ MLIRMeshDialect
MLIRPass
+ MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index fc033294eb01b28..f31db4906775687 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -19,7 +19,7 @@ using namespace mlir::mesh;
namespace {
-// Sharding of arith.empty/arith.splat
+// Sharding of arith.constant
struct ConstantShardingInterface
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
ConstantOp> {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index f84d46704852228..c789fc527e3f680 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -286,7 +286,7 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
- // No need for anything the correct sharding is already set.
+ // No need for anything if the correct sharding is already set.
return newShardOp ? newShardOp : shardOp;
}
@@ -639,6 +639,8 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
}
}
+ // Remove sharded dims offsets if they are effectively the default values,
+ // e.g. if they define equal distances between all neighboring shards.
if (offs.second.empty() && !offs.first.empty()) {
assert(offs.first.size() >= 2);
auto diff = offs.first[1] - offs.first[0];
@@ -772,7 +774,8 @@ MeshSharding::MeshSharding(Value rhs) {
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
- if(splitAxes.empty() && partialAxes.empty()) {
+ // If splitAxes and partialAxes are empty, use "empty" constructor.
+ if (splitAxes.empty() && partialAxes.empty()) {
*this = MeshSharding(shardingOp.getMeshAttr());
return;
}
@@ -793,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
MeshSharding res(mesh_);
- if(split_axes_.empty() && partial_axes_.empty()) {
+ if (split_axes_.empty() && partial_axes_.empty()) {
return res;
}
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index aaffe759b0cef08..f427d004c558ff6 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -174,11 +174,6 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();
- // check loop types
- // SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
- // if (loopTypes.empty())
- // return failure();
-
// check maps
SmallVector<AffineMap> maps = getIndexingMaps();
if (maps.empty())
@@ -453,8 +448,8 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
- if(!operandType) {
- if(operandValue.getType().isIntOrIndexOrFloat())
+ if (!operandType) {
+ if (operandValue.getType().isIntOrIndexOrFloat())
return MeshSharding();
return failure();
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 4ec8bbc0dff7d44..601af0200e78514 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -690,7 +690,7 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
std::vector<MeshSharding> res;
res.reserve(op.getNumResults());
llvm::transform(op.getResults(), std::back_inserter(res),
- [&op](OpResult result) {
+ [](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index 6bb5d4a66f39eaf..b2acbf20b3fb935 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -23,9 +23,10 @@ using namespace mlir::mesh;
namespace {
// Sharding of tensor.empty/tensor.splat
-template<typename OpTy>
+template <typename OpTy>
struct CreatorOpShardingInterface
- : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
+ : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
+ OpTy> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct CreatorOpShardingInterface
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
- return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
+ return SmallVector<AffineMap>(
+ op->getNumOperands() + op->getNumResults(),
+ {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct CreatorOpShardingInterface
newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
}
}
- newOp =
- builder.create<OpTy>(op->getLoc(), shardType, newOperands);
+ newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
@@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
- SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
+ EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
+ *ctx);
+ SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
+ *ctx);
});
}
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
deleted file mode 100644
index 0688e14b1cf7212..000000000000000
--- a/mlir/test/Dialect/Arith/mesh-spmdize.cpp
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh4x4(shape = 4x4)
-
-// CHECK-LABEL: func @test_spmdize_constant
-// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
-// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
-// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
-func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
- %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
- %ci = arith.constant 434 : i32
- return %sharding_annotated_1 : tensor<1024x1024xf32>
-}
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
new file mode 100644
index 000000000000000..6b55dd533a92c27
--- /dev/null
+++ b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
+// RUN: %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func @test_spmdize_constant
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
+// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
+// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+ %ci = arith.constant 434 : i32
+ return %sharding_annotated_1 : tensor<1024x1024xf32>
+}