[Mlir-commits] [mlir] [Mesh] initialize mesh dialect (PR #68007)
Chengji Yao
llvmlistbot at llvm.org
Tue Oct 3 08:38:09 PDT 2023
================
@@ -0,0 +1,122 @@
+//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "mesh-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Mesh dialect
+//===----------------------------------------------------------------------===//
+
+void MeshDialect::initialize() { // Registers the dialect's ops and attributes.
+ addOperations< // Template arguments are supplied by the include below.
+#define GET_OP_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+ >();
+ addAttributes< // Likewise, the attribute list comes from the include below.
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+ >();
+}
+
+/// Materializes a constant for this dialect by delegating to
+/// arith::ConstantOp::materialize with the same builder, attribute value,
+/// result type and location.
+Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  Operation *constantOp =
+      arith::ConstantOp::materialize(builder, value, type, loc);
+  return constantOp;
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.cluster op
+//===----------------------------------------------------------------------===//
+
+/// Verifies a mesh.cluster op:
+///   - its rank must be non-zero, and
+///   - dim_sizes must not contain more entries than the rank.
+LogicalResult ClusterOp::verify() {
+  const size_t clusterRank = getRank();
+  if (clusterRank == 0)
+    return emitOpError("rank of cluster is expected to be a positive integer");
+
+  // dim_sizes may list at most `rank` entries.
+  const size_t numDimSizes = getDimSizes().size();
+  if (numDimSizes > clusterRank)
+    return emitOpError(
+        "rank of dim_sizes is not expected to be larger than rank of cluster");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+/// Verifies how this mesh.shard op is combined with other mesh.shard ops that
+/// touch the same value.
+///
+/// When as_result is true:
+///   - the source must not itself be produced by a mesh.shard op, and
+///   - this op must be the only mesh.shard user of the source.
+/// When as_result is false:
+///   - the source must not be produced by a mesh.shard op whose as_result is
+///     also false.
+///
+/// Returns success() if the rules hold, otherwise the emitted op error.
+LogicalResult ShardOp::verify() {
+  Value src = getSrc();
+  // Null when the source is not the result of a mesh.shard op.
+  ShardOp producerShard = src.getDefiningOp<ShardOp>();
+
+  if (!getAsResult()) {
+    if (producerShard && !producerShard.getAsResult())
+      return emitOpError("two mesh.shard ops with as_result = false are not "
+                         "expected to be stacked together");
+    return success();
+  }
+
+  if (producerShard)
+    return emitOpError("two mesh.shard ops with as_result = true are not "
+                       "expected to be stacked together");
+
+  // This op is itself one of the users of its source, so a count above one
+  // means another mesh.shard op consumes the same value.
+  unsigned numShardUsers = llvm::count_if(
+      src.getUsers(), [](Operation *user) { return llvm::isa<ShardOp>(user); });
+  if (numShardUsers > 1)
+    return emitOpError(
+        "when more than one mesh.shard ops operate on the same tensor, all of "
+        "their as_result attributes are expected to be false");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MeshShardingAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> axes) {
+ // TODO: At present cluster symbol ref is not verified. This is due to the
+ // difficulty in fetching the corresponding symbol op based on an attribute.
+
+ DenseSet<int64_t> visitedAxes;
----------------
yaochengji wrote:
Updated.
Changed to `int8_t`, as well as the `axes` member in `MeshSharding`. The reasons why it's `int8_t` rather than your suggested `uint8_t` are:
1. At present there's no DenseArrayAttr of uint8_t type, and I chose not to add one in order to keep the binary size down.
2. If the type were `uint8_t` and a negative number were passed, it could not be detected during verification.
https://github.com/llvm/llvm-project/pull/68007
More information about the Mlir-commits
mailing list