[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)

Boian Petkantchin llvmlistbot at llvm.org
Mon Oct 23 15:05:40 PDT 2023


================
@@ -0,0 +1,529 @@
+//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <utility>
+
+#define DEBUG_TYPE "sharding-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// common util functions
+//===----------------------------------------------------------------------===//
+
+// Retrieves the mesh sharding attribute (MeshShardingAttr) for an operation
+// result by inspecting the `mesh.shard` ops among its users.
+//
+// * If some `mesh.shard` user lacks the `annotate_for_users` unit attribute,
+//   it annotates the definition itself and must be the result's only use; its
+//   sharding is returned, or failure if the result has other uses.
+// * Otherwise, when `useOperandSharding` is true, the sharding shared by all
+//   `mesh.shard` users (which then all carry `annotate_for_users`) is returned.
+//   Failure is returned if there is no such user or if their shardings differ.
+// * Otherwise failure is returned.
+static FailureOr<MeshShardingAttr>
+getMeshShardingAttr(OpResult result, bool useOperandSharding) {
+  // OpResult is a Value; no cast is needed.
+  Value val = result;
+  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    return shardOp && !shardOp.getAnnotateForUsers();
+  });
+
+  if (anyShardedForDef) {
+    // A `mesh.shard` without the `annotate_for_users` unit attr annotates the
+    // definition and is expected to be the value's only use.
+    if (!val.hasOneUse())
+      return failure();
+    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
+    return shardOp.getShard();
+  }
+
+  if (!useOperandSharding)
+    return failure();
+
+  // Single pass over the users: remember the first sharding seen and require
+  // every other `mesh.shard` user to agree with it. Because no user annotates
+  // the definition (checked above), every `mesh.shard` user seen here carries
+  // `annotate_for_users`.
+  MeshShardingAttr shardForDef;
+  for (Operation *user : val.getUsers()) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    if (!shardOp)
+      continue;
+    MeshShardingAttr shard = shardOp.getShard();
+    if (!shardForDef) {
+      shardForDef = shard;
+      continue;
+    }
+    // TODO: Deduce a reasonable mesh sharding attr for def when they are
+    // different. Until then, report failure rather than asserting: the IR
+    // being inspected is user input, and an assert would crash debug builds
+    // while silently picking the first sharding in release builds.
+    if (shard != shardForDef)
+      return failure();
+  }
+
+  if (!shardForDef)
+    return failure();
+  return shardForDef;
+}
+
+// Retrieves the mesh sharding attribute for an operation operand by looking at
+// the op that defines the operand's value. Succeeds only when that defining op
+// is a `mesh.shard`. The returned pair holds that op's `annotate_for_users`
+// flag and its sharding attribute, in that order.
+static FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand) {
+  auto definingShard = opOperand.get().getDefiningOp<ShardOp>();
+  if (!definingShard)
+    return failure();
+  return std::make_pair(definingShard.getAnnotateForUsers(),
+                        definingShard.getShard());
+}
+
+// Recursively validates an operand indexing expression. The accepted forms are:
+// a single dimension `dN`, a dimension scaled by a constant (`dN * c` or
+// `c * dN`), and sums of accepted forms. Each dimension must be smaller than
+// seenIds.size() and may appear at most once. Every dimension encountered is
+// marked in `seenIds`. Any other expression kind fails.
+static LogicalResult
+checkOperandAffineExprRecursively(AffineExpr expr,
+                                  SmallVectorImpl<bool> &seenIds) {
+  switch (expr.getKind()) {
+  case AffineExprKind::Add: {
+    auto sum = expr.cast<AffineBinaryOpExpr>();
+    // The left operand is visited first. The right one is skipped once the
+    // left one has failed.
+    return success(
+        succeeded(checkOperandAffineExprRecursively(sum.getLHS(), seenIds)) &&
+        succeeded(checkOperandAffineExprRecursively(sum.getRHS(), seenIds)));
+  }
+  case AffineExprKind::Mul: {
+    auto product = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = product.getLHS();
+    AffineExpr rhs = product.getRHS();
+    bool lhsIsDim = lhs.getKind() == AffineExprKind::DimId;
+    bool rhsIsDim = rhs.getKind() == AffineExprKind::DimId;
+    bool lhsIsConst = lhs.getKind() == AffineExprKind::Constant;
+    bool rhsIsConst = rhs.getKind() == AffineExprKind::Constant;
+    // Only `dim * constant` or `constant * dim` is allowed. The dimension
+    // itself is validated by the DimId case.
+    if (lhsIsDim && rhsIsConst)
+      return checkOperandAffineExprRecursively(lhs, seenIds);
+    if (lhsIsConst && rhsIsDim)
+      return checkOperandAffineExprRecursively(rhs, seenIds);
+    return failure();
+  }
+  case AffineExprKind::DimId: {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    // Reject dimensions that are out of range or already used.
+    if (dim >= seenIds.size() || seenIds[dim])
+      return failure();
+    seenIds[dim] = true;
+    return success();
+  }
+  default:
+    return failure();
+  }
+}
+
+// Validates `expr` as an operand indexing expression over `numDims` loop
+// dimensions, using checkOperandAffineExprRecursively. On success, returns the
+// set of dimension positions the expression references.
+static FailureOr<llvm::SmallSet<unsigned, 2>>
+checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
+  SmallVector<bool> usedDims(numDims, false);
+  if (failed(checkOperandAffineExprRecursively(expr, usedDims)))
+    return failure();
+
+  llvm::SmallSet<unsigned, 2> positions;
+  for (unsigned dim = 0; dim < numDims; ++dim)
+    if (usedDims[dim])
+      positions.insert(dim);
+  return positions;
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::verifyShardingInterfaceImpl
+//===----------------------------------------------------------------------===//
+
+LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+  Operation *op = getOperation();
+
+  // check operands and results type
+  for (Type type : op->getOperandTypes())
+    if (!llvm::isa<RankedTensorType>(type))
+      return failure();
+  for (Type type : op->getResultTypes())
+    if (!llvm::isa<RankedTensorType>(type))
+      return failure();
+
+  // check loop types
+  SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+  if (loopTypes.size() == 0)
+    return failure();
+
+  // check maps
+  SmallVector<AffineMap> maps = getIndexingMaps();
+  if (maps.size() == 0)
+    return failure();
+  unsigned numOperands = op->getNumOperands();
+  unsigned numResults = op->getNumResults();
+  if (numOperands + numResults != maps.size())
+    return failure();
+
+  for (OpResult result : op->getResults()) {
+    auto resultType = result.getType().dyn_cast<RankedTensorType>();
+    if (!resultType)
+      return failure();
+    AffineMap map = maps[numOperands + result.getResultNumber()];
+    if (!map.isProjectedPermutation()) {
----------------
sogartar wrote:

I think in general projected permutations are too restrictive. There are operations like [convolution](https://github.com/llvm/llvm-project/blob/b997ff41c11cc69cfcb6c8a3ed39ed47229cf891/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L2295) that have indexing maps that are not projected permutations.
My understanding is that after decomposition and before SPMDization you would want to have only ops with indexing maps that are projected permutations.

https://github.com/llvm/llvm-project/pull/69665


More information about the Mlir-commits mailing list