[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)

Frank Schlimbach llvmlistbot at llvm.org
Tue Feb 11 03:47:12 PST 2025


================
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+    : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+                                              ConstantOp> {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    int64_t ndims = 0;
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ndims = type.getRank();
+    }
+    return SmallVector<utils::IteratorType>(ndims,
+                                            utils::IteratorType::parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
+                                           type.getRank(), op->getContext())});
+    }
+    return {};
+  }
+
+  // Indicate failure if no result sharding exists.
+  // Otherwise mirror the result sharding if the constant is a ranked tensor.
+  // Otherwise return the replication option.
+  FailureOr<ShardingOption>
+  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+                    ArrayRef<MeshSharding> resultShardings) const {
+    if (!resultShardings[0]) {
+      return failure();
+    }
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
+      for (auto [i, axes] :
+           llvm::enumerate(resultShardings[0].getSplitAxes())) {
+        axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
+      }
+      return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
+    }
+    return ShardingOption({}, resultShardings[0].getMeshAttr());
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    auto cOp = cast<ConstantOp>(op);
+    auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
----------------
fschlimb wrote:

done


https://github.com/llvm/llvm-project/pull/124724
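
For context, external models like ConstantShardingInterface are normally attached to the owning dialect through a registration hook in the same file. The registration code falls outside the quoted hunk, so the following is only a minimal sketch; it assumes the header mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h declares the usual registerShardingInterfaceExternalModels hook:

  // Sketch only: the hook name is assumed, not taken from the quoted hunk.
  void mlir::arith::registerShardingInterfaceExternalModels(
      DialectRegistry &registry) {
    // Attach the external model lazily, once the arith dialect is loaded.
    registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
      ConstantOp::attachInterface<ConstantShardingInterface>(*ctx);
    });
  }

Using a registry extension keeps the mesh sharding dependency optional: the interface is only attached when the arith dialect is actually loaded into the context.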

