[llvm] [mlir] Implement Mesh's ShardingInterface for Linalg ops (PR #82284)

Lei Zhang via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 22:41:11 PST 2024


================
@@ -0,0 +1,336 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
----------------
antiagainst wrote:

Suggest adding a documentation comment for the helper that follows, e.g. `// Returns the corresponding mesh reduction kind for the given arith op.`

https://github.com/llvm/llvm-project/pull/82284


More information about the llvm-commits mailing list