[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 10 04:13:47 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

A collection of fixes to the mesh dialect
- allow constants in sharding propagation/spmdization
- fix tensor replication (e.g. for 0d tensors)
- improve canonicalization
- fix sharding propagation so it no longer generates too many ShardOps

A new operation `mesh.get_sharding` (`GetShardingOp`) enables exchanging sharding information (e.g. across function boundaries)

@<!-- -->yaochengji @<!-- -->AntonLydike 


---

Patch is 48.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124724.diff


17 Files Affected:

- (added) mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h (+23) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+8-5) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+23-1) 
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+3-1) 
- (modified) mlir/include/mlir/InitAllDialects.h (+2) 
- (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+3) 
- (added) mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp (+99) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+150-32) 
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+30-22) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+3-2) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+26-17) 
- (modified) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+13-8) 
- (added) mlir/test/Dialect/Arith/mesh-spmdize.mlir (+17) 
- (added) mlir/test/Dialect/Arith/sharding-propagation.mlir (+54) 
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+39-1) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+10) 
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..5addffbe571bee1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 75cb096130ca6e4..7de7842baf98abf 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,7 +51,7 @@ class MeshSharding {
   SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
-  MeshSharding() = default;
+  MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
   MeshSharding(Value rhs);
   static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
                           ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ class MeshSharding {
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
                           ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
-  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+  ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
 Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 
 // Insert shard op if there is not one that already has the same sharding.
+// Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand,
-                                         OpBuilder &builder);
+// Return the target ShardOP (new or existing).
+ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                            OpOperand &operand,
+                                            OpBuilder &builder,
+                                            ShardOp newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 6039e61a93fadc5..031e6f63bcb42cc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
     Op<Mesh_Dialect, mnemonic, traits> {
 }
 
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
   let summary = "Description of a device/process mesh.";
   let description = [{
     The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
+    OpBuilder<(ins "llvm::StringRef":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
+    )>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
+def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+  let summary = "Get the sharding of the given tensor.";
+  let description = [{
+    This operation returns the sharding of the given tensor as a MeshSharding.
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$source
+  );
+  let results = (outs
+    Mesh_Sharding:$result
+  );
+  let assemblyFormat = [{
+    $source attr-dict `:` type($source) `->` type($result)
+  }];
+}
+
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
   let summary = "Get the shard shape of a given process/device.";
   let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
       (`annotate_for_users` $annotate_for_users^)?
       attr-dict `:` type($result)
   }];
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index b4d25cef05a7b96..14aad7f9f6783d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -36,7 +36,9 @@ struct ShardingOption {
   bool empty = false;
   ShardingOption() = default;
   ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
-      : shardingArray(std::move(shardingArray)), mesh(mesh) {}
+      : shardingArray(std::move(shardingArray)), mesh(mesh) {
+    assert(this->mesh);
+  }
   static ShardingOption makeEmpty() {
     auto res = ShardingOption();
     res.empty = true;
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0da82825c82878a..33bc89279c08c32 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
   arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+  arith::registerShardingInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6149b35befe7de2..f96bda603baa63d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   ExpandOps.cpp
   IntRangeOptimizations.cpp
   ReifyValueBounds.cpp
+  ShardingInterfaceImpl.cpp
   UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
   MLIRInferIntRangeInterface
   MLIRIR
   MLIRMemRefDialect
+  MLIRMeshDialect
   MLIRPass
+  MLIRShardingInterface
   MLIRTensorDialect
   MLIRTransforms
   MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..f31db4906775687
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+    : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+                                              ConstantOp> {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto ndims = 0;
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ndims = type.getRank();
+    }
+    return SmallVector<utils::IteratorType>(ndims,
+                                            utils::IteratorType::parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
+                                           type.getRank(), op->getContext())});
+    }
+    return {};
+  }
+
+  // Indicate failure if no result sharding exists.
+  // Otherwise mirror result sharding if it is a tensor constant.
+  // Otherwise return replication option.
+  FailureOr<ShardingOption>
+  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+                    ArrayRef<MeshSharding> resultShardings) const {
+    if (!resultShardings[0]) {
+      return failure();
+    }
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
+      for (auto [i, axes] :
+           llvm::enumerate(resultShardings[0].getSplitAxes())) {
+        axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
+      }
+      return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+    }
+    return ShardingOption({}, resultShardings[0].getMeshAttr());
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    auto cOp = cast<ConstantOp>(op);
+    auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
+    if (value) {
+      if (!value.isSplat() || !resultShardings[0]) {
+        // Currently non-splat constants are not supported.
+        return failure();
+      }
+      auto sharding = resultShardings[0];
+      auto newType = cast<RankedTensorType>(shardType(
+          cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+          sharding));
+      auto newValue = value.resizeSplat(newType);
+      auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
+      spmdizationMap.map(op->getResult(0), newOp.getResult());
+      spmdizationMap.map(op, newOp.getOperation());
+    } else {
+      // `clone` will populate the mapping of old to new results.
+      (void)builder.clone(*op, spmdizationMap);
+    }
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arith::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+    ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 33460ff25e9e45d..c789fc527e3f680 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
                        ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
+  // 0d tensors cannot be sharded and must get replicated
+  if (inShape.empty()) {
+    assert(outShape.empty());
+    return;
+  }
+
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
@@ -269,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder) {
+ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                                        OpOperand &operand,
+                                                        OpBuilder &builder,
+                                                        ShardOp newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
   Operation *operandOp = operand.getOwner();
@@ -279,14 +286,17 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
-    // No need for anything the correct sharding is already set.
-    return;
+    // No need for anything if the correct sharding is already set.
+    return newShardOp ? newShardOp : shardOp;
   }
 
-  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
-  auto newShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
-                              /*annotate_for_users*/ false);
+  if (!newShardOp) {
+    auto shardingOp =
+        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+    newShardOp =
+        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+                                /*annotate_for_users*/ false);
+  }
   IRRewriter rewriter(builder);
   rewriter.replaceUsesWithIf(
       operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -294,20 +304,23 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
       });
 
   if (!shardOp || shardOp.getAnnotateForUsers()) {
-    return;
+    return newShardOp;
   }
 
-  auto newShardOp2 =
-      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
-                              /*annotate_for_users*/ true);
+  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+                                             newShardOp.getSharding(),
+                                             /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  return newShardOp;
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
+  ShardOp newShardOp;
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+    newShardOp =
+        maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
   }
 }
 
@@ -316,9 +329,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
   Operation *operandSrcOp = operandValue.getDefiningOp();
   bool isBlockArg = !operandSrcOp;
+  {
+    auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
+    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
+  }
+  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
+      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+    return;
+  }
+
+  Operation *operandOp = operand.getOwner();
   ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
 
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -432,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxesAttr> split_axes,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
-                       ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_offsets) {
+                       ArrayRef<int64_t> static_halos,
+                       ArrayRef<int64_t> static_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
-                                     static_sharded_dims_offsets),
-      {});
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
 }
 
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -453,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
       {}, {}, {}, {});
 }
 
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+                       llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+                       ArrayRef<int64_t> static_halos,
+                       ArrayRef<int64_t> static_offsets) {
+  return build(
+      b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+      MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+}
+
 void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
@@ -579,9 +611,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 namespace {
 // Sharding annotations "halo sizes" and "sharded dims offsets"
 // are a mix of attributes and dynamic values. This canonicalization moves
-// constant values to the respective attribute lists and so minimizes the number
+// constant values to the respective attribute lists, minimizing the number
 // of values.
-class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+// It also removes sharded_dims_sizes and halos if they are effectively "empty".
+class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
 public:
   using OpRewritePattern<ShardingOp>::OpRewritePattern;
 
@@ -593,14 +626,41 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
                                     op.getDynamicShardedDimsOffsets(), b);
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedOffs, /*onlyNo...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/124724


More information about the Mlir-commits mailing list