[llvm] [mlir] [mlir] Implement Mesh's ShardingInterface for Linalg ops (PR #82284)

Lei Zhang via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 22:41:10 PST 2024

@@ -0,0 +1,336 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+namespace mlir::linalg {
+// Shorthand aliases for the Mesh dialect entities referenced in this file.
+using MeshAxis = mesh::MeshAxis;
+using MeshOp = mesh::MeshOp;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ReductionKind = mesh::ReductionKind;
+using ShardingArray = mesh::ShardingArray;
+// Maps the combiner operation of a reduction to the mesh collective reduction
+// kind that combines partial results the same way. Operations without a
+// dedicated mapping yield ReductionKind::Generic.
+static ReductionKind getReductionKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+      // Floating-point operations.
+      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+      // Integer operations.
+      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+      // A bitwise-AND combiner must be reduced with AND; summing the partial
+      // results would produce a different value.
+      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
+      // NOTE(review): signed and unsigned min/max map to the same kind, so
+      // signedness is not conveyed here -- confirm it is recovered downstream.
+      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+      .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+// Matches a reduction on the first region output argument of `op` and returns
+// its combiner operation. Returns std::nullopt when no reduction is matched or
+// when the matched reduction does not consist of exactly one combiner op.
+static std::optional<Operation *> getReductionOp(LinalgOp op) {
+  SmallVector<Operation *> combiners;
+  const bool matched = static_cast<bool>(
+      matchReduction(op.getRegionOutputArgs(), 0, combiners));
+  if (matched && combiners.size() == 1)
+    return combiners.front();
+  return std::nullopt;
+}
+// Returns the reduction kind implied by the single combiner of `op`'s
+// reduction, or ReductionKind::Generic when no such combiner is found.
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+  if (std::optional<Operation *> combiner = getReductionOp(op))
+    return getReductionKind(*combiner);
+  return ReductionKind::Generic;
+}
+// Returns the mesh referenced by the first non-null sharding. Operand
+// shardings are searched before result shardings. Callers must supply at
+// least one non-null sharding.
+static MeshOp getMesh(Operation *op,
+                      ArrayRef<MeshShardingAttr> operandShardings,
+                      ArrayRef<MeshShardingAttr> resultShardings,
+                      SymbolTableCollection &symbolTable) {
+  for (MeshShardingAttr sharding : operandShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+  for (MeshShardingAttr sharding : resultShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+  // The explicit return keeps NDEBUG builds (where the assert is compiled out)
+  // from falling off the end of a non-void function, which is undefined.
+  assert(false && "expected at least one non-null operand or result sharding");
+  return nullptr;
+}
+// Builds the destination-passing-style init operand for the spmdized op.
+//
+// The initial value must enter the cross-process reduction only once;
+// otherwise it would be accumulated multiple times. The process whose linear
+// index along `reductionMeshAxes` is 0 keeps `spmdizedOperand`. Every other
+// process in that group uses the neutral tensor produced by the op's
+// PartialReductionOpInterface. The choice is materialized as an `scf.if`
+// whose single result has the type of `spmdizedOperand`.
+static Value createDestinationPassingStyleInitOperand(
+    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+  Value indexInGroup = mesh::createProcessLinearIndex(
+      meshOp.getSymName(), reductionMeshAxes, builder);
+  Value zeroIndex = builder.create<arith::ConstantIndexOp>(0);
+  Value isGroupLeader = builder.create<arith::CmpIOp>(
+      builder.getI1Type(), arith::CmpIPredicate::eq, indexInGroup, zeroIndex);
+  scf::IfOp selectOp = builder.create<scf::IfOp>(
+      spmdizedOperand.getType(), isGroupLeader,
+      /*addThenBlock=*/true, /*addElseBlock=*/true);
+
+  // Leader process: forward the original init operand unchanged.
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToEnd(&selectOp.getThenRegion().front());
+    builder.create<scf::YieldOp>(spmdizedOperand);
+  }
+
+  // Remaining processes: yield the reduction-neutral tensor, sized from the
+  // mixed sizes of `spmdizedOperand`.
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToEnd(&selectOp.getElseRegion().front());
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+    auto partialReduction =
+        llvm::cast<PartialReductionOpInterface>(op.getOperation());
+    FailureOr<Operation *> neutralTensorOp =
+        partialReduction.generateInitialTensorForPartialReduction(
+            builder, builder.getLoc(), sizes, {});
+    assert(succeeded(neutralTensorOp));
+    builder.create<scf::YieldOp>(neutralTensorOp.value()->getResult(0));
+  }
+
+  return selectOp.getResult(0);
+}
+// Returns a copy of `spmdizedOperands` in which the DPS init operand is
+// replaced by the value built by createDestinationPassingStyleInitOperand.
+// All other operands are forwarded unchanged.
+static SmallVector<Value> createDestinationPassingStyleInitOperands(
+    LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  // TODO: add support for multiple destination passing style initial value
+  // operands.
+  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
+  // needs to also support multiple DPS initial operands.
+  SmallVector<Value> result(spmdizedOperands.begin(), spmdizedOperands.end());
+  const unsigned initIdx = op.getDpsInitOperand(0)->getOperandNumber();
+  Value spmdizedInit = spmdizationMap.lookup(op->getOperand(initIdx));
+  result[initIdx] = createDestinationPassingStyleInitOperand(
+      op, spmdizedInit, reductionMeshAxes, meshOp, builder);
+  return result;
+}
+// Emits a mesh.all_reduce over the op's reduction mesh axes that are not
+// already partial axes of `resultSharding`. It then remaps
+// `unshardedLinalgOpResult` in `spmdizationMap` to the reduced value. When no
+// such axis remains, nothing is emitted and the mapping is left untouched.
+static void createAllReduceForResultWithoutPartialSharding(
+    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
+    MeshShardingAttr resultSharding, ReductionKind reductionKind,
+    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
+  SmallVector<MeshAxis> axesToReduce;
+  for (MeshAxis axis : opReductionMeshAxes) {
+    if (!llvm::is_contained(resultSharding.getPartialAxes(), axis))
+      axesToReduce.push_back(axis);
+  }
+  if (axesToReduce.empty())
+    return;
+
+  Value reduced = builder.create<mesh::AllReduceOp>(
+      spmdizationMap.lookup(unshardedLinalgOpResult),
+      resultSharding.getMesh().getValue(), axesToReduce, reductionKind);
+  spmdizationMap.map(unshardedLinalgOpResult, reduced);
+}
+// Runs createAllReduceForResultWithoutPartialSharding for each result of
+// `unshardedOp`, paired positionally with `resultShardings` (pairing stops at
+// the shorter of the two). All results share the reduction kind derived from
+// the op's combiner.
+static void createAllReduceForResultsWithoutPartialShardings(
+    LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  const ReductionKind kind = getReductionKindOfLinalgOp(unshardedOp);
+  for (auto [result, sharding] :
+       llvm::zip(unshardedOp->getResults(), resultShardings))
+    createAllReduceForResultWithoutPartialSharding(
+        result, opReductionMeshAxes, sharding, kind, spmdizationMap, builder);
+}
+// Spmdizes a Linalg op whose reduction loops are assigned to mesh axes:
+//  1. Rewrite the DPS init operand so that only the leading process of each
+//     reduction group contributes the original initial value.
+//  2. Spmdize the op as a trivially shardable operation on those operands.
+//  3. All-reduce results over reduction axes that their sharding does not
+//     already mark as partial.
+static void spmdizeLinalgOpWithShardedReduction(
+    LinalgOp op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
+    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+    ImplicitLocOpBuilder &builder) {
+  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
+  SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
+      loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+  SmallVector<Value> spmdizedLinalgOpOperands =
+      createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
+                                                reductionMeshAxes,
+                                                spmdizationMap, builder);
+
+  // The operand entries of `spmdizationMap` describe the whole spmdization
+  // blob and may be used by others, so they must not be overwritten. Map this
+  // op's rewritten operands in a private mapping instead.
+  IRMapping localMap;
+  for (auto [unsharded, spmdized] :
+       llvm::zip(op->getOperands(), spmdizedLinalgOpOperands))
+    localMap.map(unsharded, spmdized);
+
+  spmdizeTriviallyShardableOperation(*op, spmdizedLinalgOpOperands,
+                                     operandShardings, resultShardings,
+                                     localMap, symbolTable, builder);
+
+  // Publish only the result mappings back to the shared map.
+  for (Value result : op->getResults())
+    spmdizationMap.map(result, localMap.lookup(result));
+
+  // Handle partial shardings.
+  createAllReduceForResultsWithoutPartialShardings(
+      op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+}
+namespace {
+// ShardingInterface for ops that implement LinalgStructuredInterface.
+// The supported ops are only those where the indexing maps are projected
+// permutations.
+template <typename Op>
+struct StructuredOpShardingInterface
+    : public mesh::ShardingInterface::ExternalModel<
+          StructuredOpShardingInterface<Op>, Op> {
+  // Returns the iterator type of each loop of the Linalg op.
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+    return linalgOp.getIteratorTypesArray();
+  }
+  // Returns the op's operand indexing maps followed by one map per DPS init,
+  // in init order, to describe the results. Each appended map is the indexing
+  // map of the corresponding DPS init operand.
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    auto linalgOp = llvm::cast<LinalgOp>(op);
+    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
+    const int64_t numInits = linalgOp.getNumDpsInits();
+    for (int64_t initIdx = 0; initIdx < numInits; ++initIdx) {
+      const unsigned operandNumber =
+          linalgOp.getDpsInitOperand(initIdx)->getOperandNumber();
+      AffineMap initMap = maps[operandNumber];
+      maps.push_back(initMap);
+    }
+    return maps;
+  }
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+    bool allIndexingMapsAreProjectedPermutation =
+        llvm::all_of(indexingMaps, [](AffineMap map) {
+          return map.isProjectedPermutation();
+        });
+    if (!allIndexingMapsAreProjectedPermutation) {
+      // TODO: handle non-projected permutations.
+      op->emitOpError()
+          << "Only projected permutation indexing maps are supported.";
antiagainst wrote:

Nit: error messages typically start with a lowercase letter so that they compose well with the prefix. Here it would read "'linalg.*' op only projected ...". (The error message could also be reworded to read more naturally.)


More information about the llvm-commits mailing list