[Mlir-commits] [mlir] [Mesh] initialize mesh dialect (PR #68007)

Chengji Yao llvmlistbot at llvm.org
Tue Oct 3 08:38:09 PDT 2023


================
@@ -0,0 +1,122 @@
+//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "mesh-ops" // Tag spliced into the DBGS() prefix below.
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") // Writes "[mesh-ops]: " to llvm::dbgs() and yields the stream.
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Mesh dialect
+//===----------------------------------------------------------------------===//
+
+void MeshDialect::initialize() { // Registers the dialect's ops and attributes.
+  addOperations< // Template args come from MeshOps.cpp.inc under GET_OP_LIST.
+#define GET_OP_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+      >();
+  addAttributes< // Args come from MeshOpsAttributes.cpp.inc (GET_ATTRDEF_LIST).
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+      >();
+}
+
+/// Materializes a constant by delegating to arith::ConstantOp::materialize
+/// with the given builder, attribute value, type and location, and returns
+/// whatever operation that call produces.
+Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  Operation *constant =
+      arith::ConstantOp::materialize(builder, value, type, loc);
+  return constant;
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.cluster op
+//===----------------------------------------------------------------------===//
+
+/// Verifies a mesh.cluster op: the rank must be non-zero, and dim_sizes must
+/// not contain more entries than the rank.
+LogicalResult ClusterOp::verify() {
+  size_t clusterRank = getRank();
+  if (clusterRank == 0)
+    return emitOpError("rank of cluster is expected to be a positive integer");
+
+  // dim_sizes may hold fewer entries than the rank, but never more.
+  ArrayRef<int64_t> sizes = getDimSizes();
+  const bool tooManyDimSizes = sizes.size() > clusterRank;
+  if (tooManyDimSizes)
+    return emitOpError(
+        "rank of dim_sizes is not expected to be larger than rank of cluster");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+/// Verifies how this mesh.shard op is combined with other mesh.shard ops
+/// acting on the same value.
+///
+///  - as_result = true: the source must not be defined by another mesh.shard
+///    op, and the source must have at most one mesh.shard user.
+///  - as_result = false: the source must not be defined by a mesh.shard op
+///    whose as_result is also false.
+///
+/// Returns success() when none of these rules is violated; otherwise emits an
+/// op error describing the violated rule.
+LogicalResult ShardOp::verify() {
+  Value src = getSrc();
+
+  if (getAsResult()) {
+    // Any defining mesh.shard op is rejected here; its own as_result value is
+    // not inspected.
+    // NOTE(review): the diagnostic text says both ops have as_result = true,
+    // which this check does not establish - confirm the intended rule.
+    if (llvm::isa_and_nonnull<ShardOp>(src.getDefiningOp()))
+      return emitOpError("two mesh.shard ops with as_result = true are not "
+                         "expected to be stacked together");
+
+    // Count all mesh.shard users of the source value; more than one is an
+    // error when this op has as_result = true.
+    unsigned numShard = llvm::count_if(src.getUsers(), [](Operation *user) {
+      return llvm::isa<ShardOp>(user);
+    });
+    if (numShard > 1)
+      return emitOpError(
+          "when more than one mesh.shard ops operate on the same tensor, all "
+          "of their as_result attributes are expected to be false");
+
+  } else {
+    // getDefiningOp<ShardOp>() yields a null op unless the producer is a
+    // mesh.shard op, so the null check guards the attribute read.
+    ShardOp defShardOp = src.getDefiningOp<ShardOp>();
+    if (defShardOp && !defShardOp.getAsResult())
+      return emitOpError("two mesh.shard ops with as_result = false are not "
+                         "expected to be stacked together");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MeshShardingAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                         SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> axes) {
+  // TODO: At present cluster symbol ref is not verified. This is due to the
+  // difficulty in fetching the corresponding symbol op based on an attribute.
+
+  DenseSet<int64_t> visitedAxes;
----------------
yaochengji wrote:

Updated.

Changed to `int8_t`, as was the `axes` member in `MeshSharding`. The reasons it is `int8_t` rather than your suggested `uint8_t` are:
1. At present there is no DenseArrayAttr for `uint8_t`, and I chose not to add one in order to keep the binary size down.
2. If the type were `uint8_t`, a negative number passed in could not be detected during verification.

https://github.com/llvm/llvm-project/pull/68007


More information about the Mlir-commits mailing list