[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 10 04:13:47 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
A collection of fixes to the mesh dialect
- allow constants in sharding propagation/spmdization
- fixes to tensor replication (e.g. 0d tensors)
- improved canonicalization
- fixed sharding propagation incorrectly generating too many ShardOps
A new operation `mesh.GetShardingOp` (`mesh.get_sharding`) enables exchanging sharding information (e.g., on function boundaries)
@<!-- -->yaochengji @<!-- -->AntonLydike
---
Patch is 48.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124724.diff
17 Files Affected:
- (added) mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h (+23)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+8-5)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+23-1)
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+3-1)
- (modified) mlir/include/mlir/InitAllDialects.h (+2)
- (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+3)
- (added) mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp (+99)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+150-32)
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+30-22)
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+3-2)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+26-17)
- (modified) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+13-8)
- (added) mlir/test/Dialect/Arith/mesh-spmdize.mlir (+17)
- (added) mlir/test/Dialect/Arith/sharding-propagation.mlir (+54)
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+39-1)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+10)
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..5addffbe571bee1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 75cb096130ca6e4..7de7842baf98abf 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,7 +51,7 @@ class MeshSharding {
SmallVector<Value> dynamic_sharded_dims_offsets;
public:
- MeshSharding() = default;
+ MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ class MeshSharding {
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
- ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+ ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Insert shard op if there is not one that already has the same sharding.
+// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder);
+// Return the target ShardOP (new or existing).
+ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder,
+ ShardOp newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 6039e61a93fadc5..031e6f63bcb42cc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let summary = "Description of a device/process mesh.";
let description = [{
The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
+ OpBuilder<(ins "llvm::StringRef":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
+ )>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
+def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+ let summary = "Get the sharding of the given tensor.";
+ let description = [{
+ This operation returns the sharding of the given tensor as a MeshSharding.
+ }];
+ let arguments = (ins
+ AnyRankedTensor:$source
+ );
+ let results = (outs
+ Mesh_Sharding:$result
+ );
+ let assemblyFormat = [{
+ $source attr-dict `:` type($source) `->` type($result)
+ }];
+}
+
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
}];
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index b4d25cef05a7b96..14aad7f9f6783d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -36,7 +36,9 @@ struct ShardingOption {
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
- : shardingArray(std::move(shardingArray)), mesh(mesh) {}
+ : shardingArray(std::move(shardingArray)), mesh(mesh) {
+ assert(this->mesh);
+ }
static ShardingOption makeEmpty() {
auto res = ShardingOption();
res.empty = true;
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0da82825c82878a..33bc89279c08c32 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ arith::registerShardingInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6149b35befe7de2..f96bda603baa63d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
+ ShardingInterfaceImpl.cpp
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
+ MLIRMeshDialect
MLIRPass
+ MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..f31db4906775687
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+ : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+ ConstantOp> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ auto ndims = 0;
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ ndims = type.getRank();
+ }
+ return SmallVector<utils::IteratorType>(ndims,
+ utils::IteratorType::parallel);
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
+ type.getRank(), op->getContext())});
+ }
+ return {};
+ }
+
+ // Indicate failure if no result sharding exists.
+ // Otherwise mirror result sharding if it is a tensor constant.
+ // Otherwise return replication option.
+ FailureOr<ShardingOption>
+ getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings) const {
+ if (!resultShardings[0]) {
+ return failure();
+ }
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
+ for (auto [i, axes] :
+ llvm::enumerate(resultShardings[0].getSplitAxes())) {
+ axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
+ }
+ return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+ }
+ return ShardingOption({}, resultShardings[0].getMeshAttr());
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ auto cOp = cast<ConstantOp>(op);
+ auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
+ if (value) {
+ if (!value.isSplat() || !resultShardings[0]) {
+ // Currently non-splat constants are not supported.
+ return failure();
+ }
+ auto sharding = resultShardings[0];
+ auto newType = cast<RankedTensorType>(shardType(
+ cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+ sharding));
+ auto newValue = value.resizeSplat(newType);
+ auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
+ spmdizationMap.map(op->getResult(0), newOp.getResult());
+ spmdizationMap.map(op, newOp.getOperation());
+ } else {
+ // `clone` will populate the mapping of old to new results.
+ (void)builder.clone(*op, spmdizationMap);
+ }
+ return success();
+ }
+};
+} // namespace
+
+void mlir::arith::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+ registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+ ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 33460ff25e9e45d..c789fc527e3f680 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
+ // 0d tensors cannot be sharded and must get replicated
+ if (inShape.empty()) {
+ assert(outShape.empty());
+ return;
+ }
+
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
@@ -269,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder) {
+ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder,
+ ShardOp newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
@@ -279,14 +286,17 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
- // No need for anything the correct sharding is already set.
- return;
+ // No need for anything if the correct sharding is already set.
+ return newShardOp ? newShardOp : shardOp;
}
- auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
- auto newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
- /*annotate_for_users*/ false);
+ if (!newShardOp) {
+ auto shardingOp =
+ builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+ newShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+ /*annotate_for_users*/ false);
+ }
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -294,20 +304,23 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
});
if (!shardOp || shardOp.getAnnotateForUsers()) {
- return;
+ return newShardOp;
}
- auto newShardOp2 =
- builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
- /*annotate_for_users*/ true);
+ auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+ newShardOp.getSharding(),
+ /*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+ return newShardOp;
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
+ ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
- maybeInsertTargetShardingAnnotation(sharding, use, builder);
+ newShardOp =
+ maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}
@@ -316,9 +329,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
- Operation *operandOp = operand.getOwner();
Operation *operandSrcOp = operandValue.getDefiningOp();
bool isBlockArg = !operandSrcOp;
+ {
+ auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
+ assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
+ }
+ if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
+ operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+ return;
+ }
+
+ Operation *operandOp = operand.getOwner();
ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
if (shardOp && sharding == shardOp.getSharding() &&
@@ -432,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
- ArrayRef<int64_t> static_halo_sizes,
- ArrayRef<int64_t> static_sharded_dims_offsets) {
+ ArrayRef<int64_t> static_halos,
+ ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(),
- static_sharded_dims_offsets),
- {});
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -453,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
{}, {}, {}, {});
}
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+ llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+ ArrayRef<int64_t> static_halos,
+ ArrayRef<int64_t> static_offsets) {
+ return build(
+ b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+ MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+ ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+}
+
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
@@ -579,9 +611,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
namespace {
// Sharding annotations "halo sizes" and "sharded dims offsets"
// are a mix of attributes and dynamic values. This canonicalization moves
-// constant values to the respective attribute lists and so minimizes the number
+// constant values to the respective attribute lists, minimizing the number
// of values.
-class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+// It also removes sharded_dims_sizes and halos if they are effectively "empty".
+class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
using OpRewritePattern<ShardingOp>::OpRewritePattern;
@@ -593,14 +626,41 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
op.getDynamicShardedDimsOffsets(), b);
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
- failed(foldDynamicIndexList(mixedOffs, /*onlyNo...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/124724
More information about the Mlir-commits
mailing list