[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)

Frank Schlimbach llvmlistbot at llvm.org
Wed Feb 12 02:58:41 PST 2025


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/124724

>From 75e1ab9dc9959d9b7709f184d1bfc9b0297044cb Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 27 Nov 2024 16:38:24 +0100
Subject: [PATCH 01/10] Allowing constant-like operands to ShardingInterface
 ops; attaching ShardingInterface to arith::ConstantOp

---
 .../Arith/Transforms/ShardingInterfaceImpl.h  | 23 +++++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  2 +-
 mlir/include/mlir/InitAllDialects.h           |  2 +
 .../Dialect/Arith/Transforms/CMakeLists.txt   |  1 +
 .../Transforms/ShardingInterfaceImpl.cpp      | 99 +++++++++++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 19 +++-
 .../Mesh/Interfaces/ShardingInterface.cpp     | 17 ++--
 .../Mesh/Transforms/ShardingPropagation.cpp   |  3 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 35 ++++---
 .../Extensions/MeshShardingExtensions.cpp     | 15 +--
 mlir/test/Dialect/Arith/mesh-spmdize.cpp      | 17 ++++
 .../Dialect/Arith/sharding-propagation.mlir   | 54 ++++++++++
 12 files changed, 251 insertions(+), 36 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.cpp
 create mode 100644 mlir/test/Dialect/Arith/sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
new file mode 100644
index 0000000000000..5addffbe571be
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 75cb096130ca6..210b82151ede4 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -62,7 +62,7 @@ class MeshSharding {
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
                           ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
-  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+  ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0da82825c8287..33bc89279c08c 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
   arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+  arith::registerShardingInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6149b35befe7d..30dd84aff120f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   ExpandOps.cpp
   IntRangeOptimizations.cpp
   ReifyValueBounds.cpp
+  ShardingInterfaceImpl.cpp
   UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
new file mode 100644
index 0000000000000..fc033294eb01b
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+    : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+                                              ConstantOp> {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto ndims = 0;
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ndims = type.getRank();
+    }
+    return SmallVector<utils::IteratorType>(ndims,
+                                            utils::IteratorType::parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
+                                           type.getRank(), op->getContext())});
+    }
+    return {};
+  }
+
+  // Indicate failure if no result sharding exists.
+  // Otherwise, mirror the result sharding if the constant is a ranked tensor.
+  // Otherwise, return the replication option.
+  FailureOr<ShardingOption>
+  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+                    ArrayRef<MeshSharding> resultShardings) const {
+    if (!resultShardings[0]) {
+      return failure();
+    }
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
+      for (auto [i, axes] :
+           llvm::enumerate(resultShardings[0].getSplitAxes())) {
+        axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
+      }
+      return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+    }
+    return ShardingOption({}, resultShardings[0].getMeshAttr());
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    auto cOp = cast<ConstantOp>(op);
+    auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
+    if (value) {
+      if (!value.isSplat() || !resultShardings[0]) {
+        // Currently non-splat constants are not supported.
+        return failure();
+      }
+      auto sharding = resultShardings[0];
+      auto newType = cast<RankedTensorType>(shardType(
+          cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+          sharding));
+      auto newValue = value.resizeSplat(newType);
+      auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
+      spmdizationMap.map(op->getResult(0), newOp.getResult());
+      spmdizationMap.map(op, newOp.getOperation());
+    } else {
+      // `clone` will populate the mapping of old to new results.
+      (void)builder.clone(*op, spmdizationMap);
+    }
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arith::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+    ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 33460ff25e9e4..352bf476e3f57 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -316,9 +316,13 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
   Operation *operandSrcOp = operandValue.getDefiningOp();
   bool isBlockArg = !operandSrcOp;
+  if(!isBlockArg && operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+    return;
+  }
+
+  Operation *operandOp = operand.getOwner();
   ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
 
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -710,8 +714,13 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
 MeshSharding::MeshSharding(Value rhs) {
   auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
   assert(shardingOp && "expected sharding op");
-  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
-              shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
+  auto splitAxes = shardingOp.getSplitAxes().getAxes();
+  auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
+  if(splitAxes.empty() && partialAxes.empty()) {
+    *this = MeshSharding();
+    return;
+  }
+  *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
               shardingOp.getPartialType().value_or(ReductionKind::Sum),
               shardingOp.getStaticHaloSizes(),
               shardingOp.getStaticShardedDimsOffsets(),
@@ -727,6 +736,10 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<int64_t> static_sharded_dims_offsets_,
                                ArrayRef<Value> dynamic_halo_sizes_,
                                ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+  if(split_axes_.empty() && partial_axes_.empty()) {
+    return MeshSharding();
+  }
+
   MeshSharding res;
   res.mesh = mesh_;
   res.split_axes.resize(split_axes_.size());
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index c1f4d563d5b42..aae2d4ccfeed9 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -168,16 +168,16 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
 
   // check operands and results type
   for (Type type : op->getOperandTypes())
-    if (!llvm::isa<RankedTensorType>(type))
+    if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
       return failure();
   for (Type type : op->getResultTypes())
-    if (!llvm::isa<RankedTensorType>(type))
+    if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
       return failure();
 
   // check loop types
-  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
-  if (loopTypes.empty())
-    return failure();
+  // SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
+  // if (loopTypes.empty())
+  //   return failure();
 
   // check maps
   SmallVector<AffineMap> maps = getIndexingMaps();
@@ -448,7 +448,12 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
                                            const ShardingOption &shardingOption,
                                            AffineMap map) {
   Value operandValue = opOperand.get();
-  auto operandType = cast<RankedTensorType>(operandValue.getType());
+  auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
+  if(!operandType) {
+    if(operandValue.getType().isIntOrIndexOrFloat())
+      return MeshSharding();
+    return failure();
+  }
   SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
   unsigned numDims = map.getNumDims();
   for (auto it : llvm::enumerate(map.getResults())) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4bd3b425219c1..f96d54424a2fe 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -282,11 +282,12 @@ static FailureOr<ShardingOption> selectShardingOption(
 // a `mesh.shard` operation for all remaining operands and results that do not
 // have sharding annotations.
 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
   if (op->hasTrait<OpTrait::IsTerminator>() ||
+      (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
       llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
     return success();
 
-  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
   if (!shardingOp) {
     op->emitOpError() << "sharding interface is not implemented.";
     return failure();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 327ea0991e4e1..04932f11e6b43 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -636,14 +636,6 @@ shardedBlockArgumentTypes(Block &block,
   return res;
 }
 
-void spmdizeTriviallyShardableOperation(Operation &op,
-                                        ArrayRef<Value> spmdizedOperands,
-                                        ArrayRef<MeshSharding> operandShardings,
-                                        ArrayRef<MeshSharding> resultShardings,
-                                        IRMapping &spmdizationMap,
-                                        SymbolTableCollection &symbolTable,
-                                        OpBuilder &builder);
-
 static LogicalResult spmdizeOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
     ArrayRef<MeshSharding> operandShardings,
@@ -697,14 +689,15 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
   llvm::transform(op.getResults(), std::back_inserter(res),
-                  [](OpResult result) {
+                  [&op](OpResult result) {
                     TypedValue<RankedTensorType> rankedTensor =
                         dyn_cast<TypedValue<RankedTensorType>>(result);
-                    if (!rankedTensor) {
+                    if (!rankedTensor || op.hasTrait<OpTrait::ConstantLike>()) {
+                      return MeshSharding();
+                    }
+                    if (!result.hasOneUse()) {
                       return MeshSharding();
                     }
-
-                    assert(result.hasOneUse());
                     Operation *userOp = *result.getUsers().begin();
                     ShardOp shardOp = llvm::cast<ShardOp>(userOp);
                     return MeshSharding(shardOp.getSharding());
@@ -765,6 +758,7 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
                                   SymbolTableCollection &symbolTableCollection,
                                   OpBuilder &builder) {
+
   SmallVector<Location> argLocations;
   llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                   [](BlockArgument arg) { return arg.getLoc(); });
@@ -796,8 +790,12 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
   // Snapshot the original blocks to not mess up the iteration when adding new
   // blocks.
   SmallVector<Block *> originalBlocks;
-  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
-                  [](Block &b) { return &b; });
+  for (Block &b : op.getBlocks()) {
+    if (llvm::any_of(b.getOperations(),
+                     [](Operation &op) { return isa<ShardOp>(op); })) {
+      originalBlocks.push_back(&b);
+    }
+  }
 
   for (Block *block : originalBlocks) {
     if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
@@ -823,10 +821,11 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
       break;
     }
   }
-  assert(returnOp);
-  op.setType(FunctionType::get(op->getContext(),
-                               op.getFunctionBody().front().getArgumentTypes(),
-                               returnOp->getOperandTypes()));
+  if (returnOp) {
+    op.setType(FunctionType::get(
+        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
+        returnOp->getOperandTypes()));
+  }
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index f3e72abe7516e..6bb5d4a66f39e 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -22,10 +22,10 @@ using namespace mlir::mesh;
 
 namespace {
 
-// Sharding of tensor.empty
-struct EmptyOpShardingInterface
-    : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
-                                              tensor::EmptyOp> {
+// Sharding of tensor.empty/tensor.splat
+template<typename OpTy>
+struct CreatorOpShardingInterface
+    : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
     return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +38,7 @@ struct EmptyOpShardingInterface
     auto type = dyn_cast<RankedTensorType>(val.getType());
     if (!type)
       return {};
-    return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
+    return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
   }
 
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -83,7 +83,7 @@ struct EmptyOpShardingInterface
         }
       }
       newOp =
-          builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
+          builder.create<OpTy>(op->getLoc(), shardType, newOperands);
       spmdizationMap.map(op->getResult(0), newOp->getResult(0));
     } else {
       // `clone` will populate the mapping of old to new results.
@@ -100,6 +100,7 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
     DialectRegistry &registry) {
 
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
+    EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
+    SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
new file mode 100644
index 0000000000000..0688e14b1cf72
--- /dev/null
+++ b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
@@ -0,0 +1,17 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
+// RUN:   %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func @test_spmdize_constant
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
+// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
+// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+    %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+    %ci = arith.constant 434 : i32
+    return %sharding_annotated_1 : tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir
new file mode 100644
index 0000000000000..19eb340549b0b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
+// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
+func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+    %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+    %ci = arith.constant 43.4e+00 : f32
+    %o1 = tensor.empty() : tensor<1024x1024xf32>
+    %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+    return %res : tensor<1024x1024xf32>
+}
+
+// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
+// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+    %ci = arith.constant 43.4e+00 : f32
+    %o1 = tensor.empty() : tensor<1024x1024xf32>
+    %res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+    %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32>
+    return %sharding_annotated_1 : tensor<1024x1024xf32>
+}

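Taken together, patch 01 lets sharding propagation treat tensor-valued
constants like other shardable ops and lets spmdization shrink splat
constants in place. A minimal IR sketch of the intended rewrite (my own,
not part of the patch; it assumes the @mesh4x4 mesh and the
split_axes = [[0]] sharding used in the tests above):

    %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
    // mesh-spmdization rebuilds the splat with the local shard type via
    // DenseElementsAttr::resizeSplat; non-splat constants make spmdize fail.
    %cst = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
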
>From ef6671b6c5638ea3dd12f29877c9512341e3456d Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 4 Dec 2024 11:04:59 +0100
Subject: [PATCH 02/10] Better handling of replicated tensors

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  2 +-
 .../Mesh/Interfaces/ShardingInterface.h       |  4 ++-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 22 +++++++++---
 .../Mesh/Interfaces/ShardingInterface.cpp     | 36 +++++++++++--------
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  3 +-
 5 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 210b82151ede4..626f2fcf93b36 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,7 +51,7 @@ class MeshSharding {
   SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
-  MeshSharding() = default;
+  MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
   MeshSharding(Value rhs);
   static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
                           ArrayRef<MeshAxesAttr> split_axes_,
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index b4d25cef05a7b..14aad7f9f6783 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -36,7 +36,9 @@ struct ShardingOption {
   bool empty = false;
   ShardingOption() = default;
   ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
-      : shardingArray(std::move(shardingArray)), mesh(mesh) {}
+      : shardingArray(std::move(shardingArray)), mesh(mesh) {
+    assert(this->mesh);
+  }
   static ShardingOption makeEmpty() {
     auto res = ShardingOption();
     res.empty = true;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 352bf476e3f57..5e342a855d6ae 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
                        ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
+  // 0d tensors cannot be sharded and must be replicated.
+  if (inShape.empty()) {
+    assert(outShape.empty());
+    return;
+  }
+
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
@@ -318,7 +324,12 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
   Value operandValue = operand.get();
   Operation *operandSrcOp = operandValue.getDefiningOp();
   bool isBlockArg = !operandSrcOp;
-  if(!isBlockArg && operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
+  {
+    auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
+    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
+  }
+  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
+      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
     return;
   }
 
@@ -711,13 +722,15 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
   return !(*this == rhs);
 }
 
+MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
+
 MeshSharding::MeshSharding(Value rhs) {
   auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
   assert(shardingOp && "expected sharding op");
   auto splitAxes = shardingOp.getSplitAxes().getAxes();
   auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
   if(splitAxes.empty() && partialAxes.empty()) {
-    *this = MeshSharding();
+    *this = MeshSharding(shardingOp.getMeshAttr());
     return;
   }
   *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
@@ -736,12 +749,11 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<int64_t> static_sharded_dims_offsets_,
                                ArrayRef<Value> dynamic_halo_sizes_,
                                ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+  MeshSharding res(mesh_);
   if(split_axes_.empty() && partial_axes_.empty()) {
-    return MeshSharding();
+    return res;
   }
 
-  MeshSharding res;
-  res.mesh = mesh_;
   res.split_axes.resize(split_axes_.size());
   for (auto [i, axis] : llvm::enumerate(split_axes_)) {
     res.split_axes[i] =
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index aae2d4ccfeed9..aaffe759b0cef 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -286,18 +286,22 @@ mesh::detail::defaultGetShardingOption(Operation *op,
       continue;
     AffineMap map = maps[numOperands + shardingIt.index()];
     anyShardingInResultsOrOperands = true;
-    // Handle the split axes: calculate the corresponding loop index for each
-    // split axes sub-array, and then store the sub-array to
-    // shardingOption[index]
-    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
-      AffineExpr expr = std::get<0>(it);
-      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
-      auto dim = cast<AffineDimExpr>(expr);
-      unsigned index = dim.getPosition();
-      visitedLoopIndices.insert(index);
-      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
-                                    axes, index)))
-        return failure();
+    if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
+      shardingOption.mesh = shardAttr.getMeshAttr();
+    } else {
+      // Handle the split axes: calculate the corresponding loop index for each
+      // split axes sub-array, and then store the sub-array to
+      // shardingOption[index]
+      for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
+        AffineExpr expr = std::get<0>(it);
+        ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+        auto dim = cast<AffineDimExpr>(expr);
+        unsigned index = dim.getPosition();
+        visitedLoopIndices.insert(index);
+        if (failed(fillShardingOption(op, shardingOption,
+                                      shardAttr.getMeshAttr(), axes, index)))
+          return failure();
+      }
     }
 
     // Handle the partial axes: at this stage, the exact loop index/indices
@@ -323,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
     if (!shardAttr)
       continue;
 
-    anyShardingInResultsOrOperands = true;
+    anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
     AffineMap map = maps[shardingIt.index()];
     unsigned numDims = map.getNumDims();
 
@@ -454,6 +458,10 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
       return MeshSharding();
     return failure();
   }
+  // 0d tensors cannot be sharded and must be replicated.
+  if (operandType.getRank() == 0) {
+    return MeshSharding(shardingOption.mesh);
+  }
   SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
   unsigned numDims = map.getNumDims();
   for (auto it : llvm::enumerate(map.getResults())) {
@@ -584,7 +592,7 @@ static bool
 isValueCompatibleWithFullReplicationSharding(Value value,
                                              MeshSharding sharding) {
   if (isa<RankedTensorType>(value.getType())) {
-    return sharding && isFullReplication(sharding);
+    return isFullReplication(sharding);
   }
 
   return !sharding;
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 04932f11e6b43..27297a8be5d06 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -561,7 +561,8 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
   // If source and destination sharding are the same, no need to do anything.
-  if (sourceSharding == targetSharding) {
+  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
+                                           isFullReplication(targetSharding))) {
     return sourceShard;
   }
 

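The net effect of patch 02: a fully replicated sharding still remembers its
mesh, 0d tensors are always treated as replicated, and resharding between
two full replications becomes a no-op. A minimal sketch (my own, assuming
the @mesh_1d mesh from the spmdization tests, where split_axes = [[]]
denotes full replication):

    %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
    %a = mesh.shard %arg0 to %s1 : tensor<2xf32>
    %s2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
    %b = mesh.shard %a to %s2 annotate_for_users : tensor<2xf32>
    // Both annotations are full replications, so reshard() now simply
    // forwards the source shard instead of inserting communication.
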
>From 6a633617bdfc654817b7d2f1d5cb13d84bb84d1a Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 5 Dec 2024 12:55:23 +0100
Subject: [PATCH 03/10] canonicalize ShardOp and ShardingOp

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  3 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 87 +++++++++++++++++--
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  2 +-
 mlir/test/Dialect/Mesh/canonicalization.mlir  | 40 ++++++++-
 4 files changed, 122 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 6039e61a93fad..531020930768e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
     Op<Mesh_Dialect, mnemonic, traits> {
 }
 
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
   let summary = "Description of a device/process mesh.";
   let description = [{
     The mesh.mesh operation is a symbol operation that identifies a specific
@@ -460,6 +460,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
       (`annotate_for_users` $annotate_for_users^)?
       attr-dict `:` type($result)
   }];
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 5e342a855d6ae..6a1498c0f6814 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -594,9 +594,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 namespace {
 // Sharding annotations "halo sizes" and "sharded dims offsets"
 // are a mix of attributes and dynamic values. This canonicalization moves
-// constant values to the respective attribute lists and so minimizes the number
+// constant values to the respective attribute lists, minimizing the number
 // of values.
-class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+// It also removes sharded_dims_offsets and halos if they are effectively "empty".
+class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
 public:
   using OpRewritePattern<ShardingOp>::OpRewritePattern;
 
@@ -608,14 +609,39 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
                                     op.getDynamicShardedDimsOffsets(), b);
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
-      return failure();
-    }
+    bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
+                    succeeded(foldDynamicIndexList(mixedOffs, true));
 
     auto halos = decomposeMixedValues(mixedHalos);
     auto offs = decomposeMixedValues(mixedOffs);
 
+    if (halos.second.empty() && !halos.first.empty()) {
+      if (halos.first[0] == 0 && llvm::all_equal(halos.first)) {
+        halos.first.clear();
+        modified = true;
+      }
+    }
+
+    if (offs.second.empty() && !offs.first.empty()) {
+      assert(offs.first.size() >= 2);
+      auto diff = offs.first[1] - offs.first[0];
+      bool all_same = offs.first.size() > 2;
+      for (auto i = 2u; i < offs.first.size(); ++i) {
+        if (offs.first[i] - offs.first[i - 1] != diff) {
+          all_same = false;
+          break;
+        }
+      }
+      if (all_same) {
+        offs.first.clear();
+        modified = true;
+      }
+    }
+
+    if (!modified) {
+      return failure();
+    }
+
     op.setStaticHaloSizes(halos.first);
     op.getDynamicHaloSizesMutable().assign(halos.second);
     op.setStaticShardedDimsOffsets(offs.first);
@@ -628,7 +654,7 @@ class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
 
 void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                              mlir::MLIRContext *context) {
-  results.add<FoldDynamicLists>(context);
+  results.add<NormalizeSharding>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -796,6 +822,53 @@ void ShardOp::getAsmResultNames(
   setNameFn(getResult(), "sharding_annotated");
 }
 
+namespace {
+// Determine if the given ShardOp is a duplicate of another ShardOp
+// on the same value. This can happen if constant values are sharded.
+class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
+public:
+  using OpRewritePattern<ShardOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
+    // Get the use-list of the value being sharded and check if it has more than
+    // one use.
+    Value value = op.getSrc();
+    if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
+      return failure();
+    }
+
+    // Iterate through the uses of the value to find a duplicate ShardOp.
+    for (auto &use : value.getUses()) {
+      if (use.getOwner() != op.getOperation()) {
+        auto otherOp = dyn_cast<ShardOp>(use.getOwner());
+        if (!otherOp || !otherOp->isBeforeInBlock(op)) {
+          return failure();
+        }
+        // Create a MeshSharding object for the current and the other ShardOp
+        // If the two are equal replace current op with the other op.
+        MeshSharding currentSharding(op.getSharding());
+        MeshSharding otherSharding(otherOp.getSharding());
+        if (currentSharding == otherSharding) {
+          b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
+          b.eraseOp(op.getOperation());
+        } else {
+          // Use the other sharding as the input of this op.
+          op.getSrcMutable().assign(otherOp.getResult());
+        }
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                          mlir::MLIRContext *context) {
+  results.add<FoldDuplicateShardOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_multi_index op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 27297a8be5d06..e6fe0fd5d1e87 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -693,7 +693,7 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
                   [&op](OpResult result) {
                     TypedValue<RankedTensorType> rankedTensor =
                         dyn_cast<TypedValue<RankedTensorType>>(result);
-                    if (!rankedTensor || op.hasTrait<OpTrait::ConstantLike>()) {
+                    if (!rankedTensor) {
                       return MeshSharding();
                     }
                     if (!result.hasOneUse()) {
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index f0112d689805d..aff07bbf8a214 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -207,4 +207,42 @@ func.func @test_shard_offs() -> !mesh.sharding {
   // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
   %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
   return %sharding : !mesh.sharding
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops
+func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
+  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+  %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+  %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+  %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+  %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+  %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+  // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+  return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops_diff
+func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
+  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+  %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
+  %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
+  // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
+  %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+  %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+  %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
+  %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+  // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+  return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}

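The FoldDuplicateShardOp pattern added here deduplicates exactly the kind of
annotations the constant sharding from patch 01 produces. A condensed
before/after sketch of my own on @mesh4x4 (equal shardings; when the
shardings differ, the later ShardOp is instead rewired to consume the
earlier one):

    // Before canonicalization: the same constant is annotated twice.
    %s = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
    %a = mesh.shard %cst to %s : tensor<1024x1024xf32>
    %b = mesh.shard %cst to %s : tensor<1024x1024xf32>
    // After: %b is erased and all of its uses are replaced with %a.
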
>From d52ca9a108e7b3e4d3d3dd42d7482e5162035df5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 19 Dec 2024 13:19:32 +0100
Subject: [PATCH 04/10] sharding propagation: add only one ShardOp for each
 result

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h |  9 ++++--
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp        | 33 +++++++++++++--------
 2 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 626f2fcf93b36..7de7842baf98a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
 Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 
 // Insert shard op if there is not one that already has the same sharding.
+// Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand,
-                                         OpBuilder &builder);
+// Return the target ShardOp (new or existing).
+ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                            OpOperand &operand,
+                                            OpBuilder &builder,
+                                            ShardOp newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6a1498c0f6814..2fff67c44a8ac 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -275,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder) {
+ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                                        OpOperand &operand,
+                                                        OpBuilder &builder,
+                                                        ShardOp newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
   Operation *operandOp = operand.getOwner();
@@ -286,13 +287,16 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
     // No need for anything the correct sharding is already set.
-    return;
+    return newShardOp ? newShardOp : shardOp;
   }
 
-  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
-  auto newShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
-                              /*annotate_for_users*/ false);
+  if (!newShardOp) {
+    auto shardingOp =
+        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+    newShardOp =
+        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+                                /*annotate_for_users*/ false);
+  }
   IRRewriter rewriter(builder);
   rewriter.replaceUsesWithIf(
       operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -300,20 +304,23 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
       });
 
   if (!shardOp || shardOp.getAnnotateForUsers()) {
-    return;
+    return newShardOp;
   }
 
-  auto newShardOp2 =
-      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
-                              /*annotate_for_users*/ true);
+  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+                                             newShardOp.getSharding(),
+                                             /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  return newShardOp;
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
+  ShardOp newShardOp;
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+    newShardOp =
+        maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
   }
 }
 

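With this patch, maybeInsertTargetShardingAnnotation threads the ShardOp
created for the first use back into the calls for the remaining uses, so a
multi-use result receives a single annotation. A sketch with hypothetical
names (%v, %sharding, %sa_*):

    // Before: sharding propagation created one mesh.shard per use of %v.
    %sa_0 = mesh.shard %v to %sharding : tensor<1024x1024xf32>
    %sa_1 = mesh.shard %v to %sharding : tensor<1024x1024xf32>
    // After: a single ShardOp is created and reused for every use of %v.
    %sa = mesh.shard %v to %sharding : tensor<1024x1024xf32>
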
>From 3c76df3d4552953b0d5fa6719b31d68796fda199 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 13 Jan 2025 15:36:54 +0100
Subject: [PATCH 05/10] Adding sharding extraction operation and op tests;
 handling GetShardingOp in ShardingPropagation

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 21 ++++++++++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 22 ++++++++++++++-----
 .../Mesh/Transforms/ShardingPropagation.cpp   |  2 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  9 ++++++++
 mlir/test/Dialect/Mesh/ops.mlir               | 10 +++++++++
 mlir/test/Dialect/Mesh/spmdization.mlir       | 14 ++++++++++++
 6 files changed, 71 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 531020930768e..031e6f63bcb42 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
+    OpBuilder<(ins "llvm::StringRef":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
+    )>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
+def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+  let summary = "Get the sharding of the given tensor.";
+  let description = [{
+    This operation returns the sharding of the given tensor as a !mesh.sharding value.
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$source
+  );
+  let results = (outs
+    Mesh_Sharding:$result
+  );
+  let assemblyFormat = [{
+    $source attr-dict `:` type($source) `->` type($result)
+  }];
+}
+
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
   let summary = "Get the shard shape of a given process/device.";
   let description = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 2fff67c44a8ac..f84d467048522 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -454,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxesAttr> split_axes,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
-                       ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_offsets) {
+                       ArrayRef<int64_t> static_halos,
+                       ArrayRef<int64_t> static_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
-                                     static_sharded_dims_offsets),
-      {});
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
 }
 
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -475,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
       {}, {}, {}, {});
 }
 
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+                       llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+                       ArrayRef<int64_t> static_halos,
+                       ArrayRef<int64_t> static_offsets) {
+  return build(
+      b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+      MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+}
+
 void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index f96d54424a2fe..8c989cce63406 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -285,7 +285,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
   if (op->hasTrait<OpTrait::IsTerminator>() ||
       (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
-      llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
+      llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
     return success();
 
   if (!shardingOp) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e6fe0fd5d1e87..4ec8bbc0dff7d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -738,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
   if (isa<ShardingOp>(op)) {
     return success();
   }
+  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
+    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp) {
+      return op.emitError("expected a shard op as source of get_sharding");
+    }
+    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
+    spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
+    return success();
+  }
 
   ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
   if (shardOp) {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 978de4939ee77..dae21655afb23 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -164,6 +164,16 @@ func.func @mesh_shard_shape() {
   return
 }
 
+// CHECK-LABEL: func @mesh_get_sharding
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
+  %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
+  // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
+  %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
+  return %0 : !mesh.sharding
+}
+
 // CHECK-LABEL: func @mesh_shape
 func.func @mesh_shape() -> (index, index) {
   // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index c1b96fda0f4a7..59f7162e21013 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -4,6 +4,20 @@
 
 mesh.mesh @mesh_1d(shape = 2)
 
+// CHECK-LABEL: func @return_sharding
+func.func @return_sharding(
+  // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
+  %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
+) -> (tensor<2xf32>, !mesh.sharding) {
+  %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
+  // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
+  %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
+  // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
+  return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
+}
+
 // CHECK-LABEL: func @full_replication
 func.func @full_replication(
   // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>

>From 508095ac58bd85dfd6a0cdc7bad93ef57fbc8610 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 28 Jan 2025 10:39:46 +0100
Subject: [PATCH 06/10] comments

adding libs

clang-format

renaming mesh-spmdize.cpp -> mesh-spmdize.mlir and fixing format
---
 .../Dialect/Arith/Transforms/CMakeLists.txt    |  2 ++
 .../Arith/Transforms/ShardingInterfaceImpl.cpp |  2 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp           |  9 ++++++---
 .../Mesh/Interfaces/ShardingInterface.cpp      |  9 ++-------
 .../Dialect/Mesh/Transforms/Spmdization.cpp    |  2 +-
 .../Extensions/MeshShardingExtensions.cpp      | 18 +++++++++++-------
 mlir/test/Dialect/Arith/mesh-spmdize.cpp       | 17 -----------------
 mlir/test/Dialect/Arith/mesh-spmdize.mlir      | 17 +++++++++++++++++
 8 files changed, 40 insertions(+), 36 deletions(-)
 delete mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.cpp
 create mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.mlir

diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 30dd84aff120f..f96bda603baa6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -27,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
   MLIRInferIntRangeInterface
   MLIRIR
   MLIRMemRefDialect
+  MLIRMeshDialect
   MLIRPass
+  MLIRShardingInterface
   MLIRTensorDialect
   MLIRTransforms
   MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index fc033294eb01b..f31db49067756 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -19,7 +19,7 @@ using namespace mlir::mesh;
 
 namespace {
 
-// Sharding of arith.empty/arith.splat
+// Sharding of arith.constant
 struct ConstantShardingInterface
     : public ShardingInterface::ExternalModel<ConstantShardingInterface,
                                               ConstantOp> {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index f84d467048522..c789fc527e3f6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -286,7 +286,7 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
-    // No need for anything the correct sharding is already set.
+    // No need for anything if the correct sharding is already set.
     return newShardOp ? newShardOp : shardOp;
   }
 
@@ -639,6 +639,8 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
       }
     }
 
+    // Remove sharded dims offsets if they are effectively the default values,
+    // e.g. if they describe equal distances between all neighboring shards.
     if (offs.second.empty() && !offs.first.empty()) {
       assert(offs.first.size() >= 2);
       auto diff = offs.first[1] - offs.first[0];
@@ -772,7 +774,8 @@ MeshSharding::MeshSharding(Value rhs) {
   assert(shardingOp && "expected sharding op");
   auto splitAxes = shardingOp.getSplitAxes().getAxes();
   auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
-  if(splitAxes.empty() && partialAxes.empty()) {
+  // If splitAxes and partialAxes are empty, use "empty" constructor.
+  if (splitAxes.empty() && partialAxes.empty()) {
     *this = MeshSharding(shardingOp.getMeshAttr());
     return;
   }
@@ -793,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<Value> dynamic_halo_sizes_,
                                ArrayRef<Value> dynamic_sharded_dims_offsets_) {
   MeshSharding res(mesh_);
-  if(split_axes_.empty() && partial_axes_.empty()) {
+  if (split_axes_.empty() && partial_axes_.empty()) {
     return res;
   }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index aaffe759b0cef..f427d004c558f 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -174,11 +174,6 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
     if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
       return failure();
 
-  // check loop types
-  // SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
-  // if (loopTypes.empty())
-  //   return failure();
-
   // check maps
   SmallVector<AffineMap> maps = getIndexingMaps();
   if (maps.empty())
@@ -453,8 +448,8 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
                                            AffineMap map) {
   Value operandValue = opOperand.get();
   auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
-  if(!operandType) {
-    if(operandValue.getType().isIntOrIndexOrFloat())
+  if (!operandType) {
+    if (operandValue.getType().isIntOrIndexOrFloat())
       return MeshSharding();
     return failure();
   }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 4ec8bbc0dff7d..601af0200e785 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -690,7 +690,7 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
   llvm::transform(op.getResults(), std::back_inserter(res),
-                  [&op](OpResult result) {
+                  [](OpResult result) {
                     TypedValue<RankedTensorType> rankedTensor =
                         dyn_cast<TypedValue<RankedTensorType>>(result);
                     if (!rankedTensor) {
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index 6bb5d4a66f39e..b2acbf20b3fb9 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -23,9 +23,10 @@ using namespace mlir::mesh;
 namespace {
 
 // Sharding of tensor.empty/tensor.splat
-template<typename OpTy>
+template <typename OpTy>
 struct CreatorOpShardingInterface
-    : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
+    : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
+                                              OpTy> {
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
     return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct CreatorOpShardingInterface
     auto type = dyn_cast<RankedTensorType>(val.getType());
     if (!type)
       return {};
-    return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
+    return SmallVector<AffineMap>(
+        op->getNumOperands() + op->getNumResults(),
+        {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
   }
 
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct CreatorOpShardingInterface
           newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
         }
       }
-      newOp =
-          builder.create<OpTy>(op->getLoc(), shardType, newOperands);
+      newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
       spmdizationMap.map(op->getResult(0), newOp->getResult(0));
     } else {
       // `clone` will populate the mapping of old to new results.
@@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
     DialectRegistry &registry) {
 
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
-    SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
+    EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
+        *ctx);
+    SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
+        *ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
deleted file mode 100644
index 0688e14b1cf72..0000000000000
--- a/mlir/test/Dialect/Arith/mesh-spmdize.cpp
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: mlir-opt \
-// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
-// RUN:   %s | FileCheck %s
-
-mesh.mesh @mesh4x4(shape = 4x4)
-
-// CHECK-LABEL: func @test_spmdize_constant
-// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
-// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
-// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
-func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
-    %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-    %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
-    %ci = arith.constant 434 : i32
-    return %sharding_annotated_1 : tensor<1024x1024xf32>
-}
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
new file mode 100644
index 0000000000000..6b55dd533a92c
--- /dev/null
+++ b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
+// RUN:   %s | FileCheck %s
+
+mesh.mesh @mesh4x4(shape = 4x4)
+
+// CHECK-LABEL: func @test_spmdize_constant
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
+// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
+// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+  %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+  %ci = arith.constant 434 : i32
+  return %sharding_annotated_1 : tensor<1024x1024xf32>
+}

>From 818baedf889c0a48f7a4a9c1879a1b7f14230597 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 11 Feb 2025 12:46:44 +0100
Subject: [PATCH 07/10] assert expected ArrayRef argument size

---
 .../Arith/Transforms/ShardingInterfaceImpl.cpp  | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index f31db49067756..ff1625877efcb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -46,18 +46,20 @@ struct ConstantShardingInterface
   FailureOr<ShardingOption>
   getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
                     ArrayRef<MeshSharding> resultShardings) const {
-    if (!resultShardings[0]) {
+    assert(resultShardings.size() == 1 &&
+           "Expecting exactly one result sharding for arith.constant");
+    auto resultSharding = resultShardings[0];
+    if (!resultSharding) {
       return failure();
     }
     if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
-      ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
-      for (auto [i, axes] :
-           llvm::enumerate(resultShardings[0].getSplitAxes())) {
+      ShardingArray axesArray(resultSharding.getSplitAxes().size());
+      for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
         axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
       }
-      return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+      return ShardingOption(axesArray, resultSharding.getMeshAttr());
     }
-    return ShardingOption({}, resultShardings[0].getMeshAttr());
+    return ShardingOption({}, resultSharding.getMeshAttr());
   }
 
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -67,8 +69,7 @@ struct ConstantShardingInterface
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
     auto cOp = cast<ConstantOp>(op);
-    auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
-    if (value) {
+    if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
       if (!value.isSplat() || !resultShardings[0]) {
         // Currently non-splat constants are not supported.
         return failure();
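
To illustrate the intent with IR (a sketch in the style of the tests in this
PR, not additional test content): given an annotated tensor constant

    %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
    %s = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
    %annotated = mesh.shard %cst to %s : tensor<1024x1024xf32>

getShardingOption derives the option [[0]] on @mesh4x4 from the result
sharding; for a scalar constant (no RankedTensorType result) the split-axes
array is empty and only the mesh symbol is propagated.
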

>From 5702b45e017b7d2a306cac702859522a4ff77610 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 11 Feb 2025 13:12:46 +0100
Subject: [PATCH 08/10] added sharding example

---
 mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index ff1625877efcb..62d137a4cfb0e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -20,6 +20,11 @@ using namespace mlir::mesh;
 namespace {
 
 // Sharding of arith.constant
+// RankedTensor constants can be sharded like any other tensor.
+//   %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+//   %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+// Scalar constants are always replicated and need no sharding annotation.
+
 struct ConstantShardingInterface
     : public ShardingInterface::ExternalModel<ConstantShardingInterface,
                                               ConstantOp> {

>From 405e4b04b0213fbb89594b4b26385c9e7b95c898 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 12 Feb 2025 11:01:57 +0100
Subject: [PATCH 09/10] comments and nicer code (from review)

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 35 +++++++++++++++-------------
 mlir/test/Dialect/Mesh/ops.mlir      |  2 --
 2 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c789fc527e3f6..561b1ef3b1c39 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -629,30 +629,33 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
     bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
                     succeeded(foldDynamicIndexList(mixedOffs, true));
 
-    auto halos = decomposeMixedValues(mixedHalos);
-    auto offs = decomposeMixedValues(mixedOffs);
+    auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
+    auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
 
-    if (halos.second.empty() && !halos.first.empty()) {
-      if (halos.first[0] == 0 && llvm::all_equal(halos.first)) {
-        halos.first.clear();
+    if (dynamicHalos.empty() && !staticHalos.empty()) {
+      if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
+        staticHalos.clear();
         modified = true;
       }
     }
 
     // Remove sharded dims offsets if they are effectively the default values,
     // e.g. if they describe equal distances between all neighboring shards.
-    if (offs.second.empty() && !offs.first.empty()) {
-      assert(offs.first.size() >= 2);
-      auto diff = offs.first[1] - offs.first[0];
-      bool all_same = offs.first.size() > 2;
-      for (auto i = 2u; i < offs.first.size(); ++i) {
-        if (offs.first[i] - offs.first[i - 1] != diff) {
+    // Requires static-only offsets. The distance is taken as the difference
+    // between the first two offsets; the offsets are removed only if all
+    // consecutive distances are equal.
+    if (dynamicOffs.empty() && !staticOffs.empty()) {
+      assert(staticOffs.size() >= 2);
+      auto diff = staticOffs[1] - staticOffs[0];
+      bool all_same = staticOffs.size() > 2;
+      for (auto i = 2u; i < staticOffs.size(); ++i) {
+        if (staticOffs[i] - staticOffs[i - 1] != diff) {
           all_same = false;
           break;
         }
       }
       if (all_same) {
-        offs.first.clear();
+        staticOffs.clear();
         modified = true;
       }
     }
@@ -661,10 +664,10 @@ class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
       return failure();
     }
 
-    op.setStaticHaloSizes(halos.first);
-    op.getDynamicHaloSizesMutable().assign(halos.second);
-    op.setStaticShardedDimsOffsets(offs.first);
-    op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+    op.setStaticHaloSizes(staticHalos);
+    op.getDynamicHaloSizesMutable().assign(dynamicHalos);
+    op.setStaticShardedDimsOffsets(staticOffs);
+    op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
 
     return success();
   }
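
As an example of this normalization (a sketch; the sharded_dims_offsets
assembly form is assumed from the ShardingOp definition):

    %s = mesh.sharding @mesh4x4 split_axes = [[0]]
        sharded_dims_offsets = [0, 256, 512, 768, 1024] : !mesh.sharding

describes four equally sized shards along dim 0. The offsets carry no
information beyond the default block distribution, so the pattern rewrites
the op to

    %s = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
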
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index dae21655afb23..43a75bf3d8040 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -167,8 +167,6 @@ func.func @mesh_shard_shape() {
 // CHECK-LABEL: func @mesh_get_sharding
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
-  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
-  %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
   // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
   %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
   return %0 : !mesh.sharding

>From fc92f0c58d11f896a013bc1e6f29f75fb72cc90a Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 12 Feb 2025 11:58:25 +0100
Subject: [PATCH 10/10] maybeInsertTargetShardingAnnotation: accept newShardOp by reference

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h |  9 ++++-----
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp        | 20 +++++++++++---------
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 7de7842baf98a..fc5cfffea27a7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -203,11 +203,10 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Insert shard op if there is not one that already has the same sharding.
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
-// Return the target ShardOP (new or existing).
-ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                            OpOperand &operand,
-                                            OpBuilder &builder,
-                                            ShardOp newShardOp);
+// May update newShardOp to the target ShardOp (new or existing).
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                         OpOperand &operand, OpBuilder &builder,
+                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 561b1ef3b1c39..12e1ec6d717ea 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -275,10 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                        OpOperand &operand,
-                                                        OpBuilder &builder,
-                                                        ShardOp newShardOp) {
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder,
+                                                     ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
   Operation *operandOp = operand.getOwner();
@@ -287,7 +287,10 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
     // No need for anything if the correct sharding is already set.
-    return newShardOp ? newShardOp : shardOp;
+    if (!newShardOp) {
+      newShardOp = shardOp;
+    }
+    return;
   }
 
   if (!newShardOp) {
@@ -304,14 +307,14 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
       });
 
   if (!shardOp || shardOp.getAnnotateForUsers()) {
-    return newShardOp;
+    return;
   }
 
   auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                              newShardOp.getSharding(),
                                              /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
-  return newShardOp;
+  return;
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
@@ -319,8 +322,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    newShardOp =
-        maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
+    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
   }
 }
 



More information about the Mlir-commits mailing list