[llvm] [mlir] [mlir] Implement Mesh's ShardingInterface for Linalg ops (PR #82284)

Lei Zhang via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 22:41:09 PST 2024


================
@@ -0,0 +1,336 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+/// Maps the given operation to the mesh `ReductionKind` it corresponds to.
+///
+/// The dispatch is purely on the operation's type: each listed `arith` op is
+/// translated to its matching reduction kind, and any operation that is not
+/// listed yields `ReductionKind::Generic`.
+static ReductionKind getReductionKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+      // Floating-point operations.
+      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+      // Integer operations.
+      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+      // A bitwise AND combiner must not be treated as a sum; keep it
+      // consistent with the OrIOp/XOrIOp cases above.
+      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
+      // Signed and unsigned min/max both map to the same reduction kind.
+      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+      // Anything else has no dedicated reduction kind.
+      .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+static std::optional<Operation *> getReductionOp(LinalgOp op) {
----------------
antiagainst wrote:

I'd call it `getCombinerOp` to be consistent.

https://github.com/llvm/llvm-project/pull/82284


More information about the llvm-commits mailing list