[Mlir-commits] [llvm] [mlir] [mlir] Implement Mesh's ShardingInterface for Linalg ops (PR #82284)
Lei Zhang
llvmlistbot at llvm.org
Mon Mar 4 22:41:10 PST 2024
================
@@ -0,0 +1,336 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+// Shorthand aliases for the Mesh dialect names used throughout this file.
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+// Maps the scalar combiner op of a Linalg reduction body to the Mesh
+// reduction kind that is passed to the cross-process mesh.all_reduce.
+// Ops without a known counterpart map to ReductionKind::Generic.
+static ReductionKind getReductionKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+      // Floating-point operations.
+      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+      // Integer operations.
+      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+      // A bitwise-AND combiner must be all-reduced with BitwiseAnd; reducing
+      // it as Sum would produce wrong results.
+      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
+      // NOTE(review): signed and unsigned max/min map to the same kind because
+      // ReductionKind carries no signedness; confirm that the collective
+      // lowering interprets the element type's signedness correctly.
+      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+      .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+// Returns the combiner op of the reduction that feeds the first region output
+// argument of `op`. Returns std::nullopt when matchReduction finds no
+// reduction there, or when the reduction uses more than one combiner op.
+static std::optional<Operation *> getReductionOp(LinalgOp op) {
+  SmallVector<Operation *> combiners;
+  Value matched =
+      matchReduction(op.getRegionOutputArgs(), /*redPos=*/0, combiners);
+  const bool hasSingleCombiner = matched && combiners.size() == 1;
+  if (!hasSingleCombiner)
+    return std::nullopt;
+  return combiners.front();
+}
+
+// Returns the Mesh reduction kind corresponding to the single combiner op of
+// `op`'s reduction. Falls back to ReductionKind::Generic when getReductionOp
+// cannot identify exactly one combiner.
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+  if (std::optional<Operation *> combiner = getReductionOp(op))
+    return getReductionKind(*combiner);
+  return ReductionKind::Generic;
+}
+
+// Resolves the mesh referenced by the first non-null sharding. Operand
+// shardings are searched before result shardings.
+// Precondition: at least one operand or result sharding is non-null; calling
+// this with only null shardings is a programming error.
+static MeshOp getMesh(Operation *op,
+                      ArrayRef<MeshShardingAttr> operandShardings,
+                      ArrayRef<MeshShardingAttr> resultShardings,
+                      SymbolTableCollection &symbolTable) {
+  for (MeshShardingAttr sharding : operandShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  for (MeshShardingAttr sharding : resultShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  // A bare assert(false) is compiled out under NDEBUG. Control would then
+  // fall off the end of this non-void function, which is undefined behavior.
+  llvm_unreachable("expected at least one non-null operand or result sharding");
+}
+
+// Choose the operand based on the current process index along the reduction
+// mesh axes.
+// We need to use the initial value only once to avoid including it in the
+// reduction multiple times.
+// In each process group only the leading process with linear index 0 would use
+// the original operand.
+// The other processes would use the reduction operation neutral tensor.
+//
+// The choice is emitted as an scf.if on "linear index in the reduction group
+// == 0". The then-branch yields `spmdizedOperand`. The else-branch yields the
+// tensor produced by
+// PartialReductionOpInterface::generateInitialTensorForPartialReduction, with
+// the same (possibly dynamic) sizes as `spmdizedOperand`. Returns the scf.if's
+// result.
+// `op` must implement PartialReductionOpInterface (llvm::cast asserts
+// otherwise).
+static Value createDestinationPassingStyleInitOperand(
+    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+  // Linear index of the current process, restricted to the reduction axes.
+  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
+      meshOp.getSymName(), reductionMeshAxes, builder);
+  Value zero = builder.create<arith::ConstantIndexOp>(0);
+  Value isLeadProcess = builder.create<arith::CmpIOp>(
+      builder.getI1Type(), arith::CmpIPredicate::eq,
+      processLinearIndexInReductionGroup, zero);
+  // scf.if with both a then and an else block, yielding one value of the
+  // operand's type.
+  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+                                             isLeadProcess, true, true);
+  // Then block: the lead process keeps the original init value.
+  {
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+    builder.create<scf::YieldOp>(spmdizedOperand);
+  }
+
+  // Else block: every other process starts from the reduction's neutral
+  // tensor, sized like the original operand.
+  {
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
+    SmallVector<OpFoldResult> shape =
+        tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+    PartialReductionOpInterface partialReductionIface =
+        llvm::cast<PartialReductionOpInterface>(op.getOperation());
+    FailureOr<Operation *> reductionNeutralTensorOp =
+        partialReductionIface.generateInitialTensorForPartialReduction(
+            builder, builder.getLoc(), shape, {});
+    // NOTE(review): only checked in debug builds. In release builds a failure
+    // here reaches .value() on an empty result.
+    assert(succeeded(reductionNeutralTensorOp));
+    builder.create<scf::YieldOp>(
+        reductionNeutralTensorOp.value()->getResult(0));
+  }
+  return ifOp.getResult(0);
+}
+
+// Builds the operand list for the spmdized Linalg op. The result is a copy of
+// `spmdizedOperands` in which the destination-passing-style (DPS) init operand
+// is replaced by the value from createDestinationPassingStyleInitOperand.
+// The spmdized init value itself is looked up in `spmdizationMap`.
+//
+// TODO: add support for multiple destination passing style initial value
+// operands.
+// PartialReductionOpInterface::generateInitialTensorForPartialReduction
+// needs to also support multiple DPS initial operands.
+// Currently only the first DPS init operand is rewritten.
+static SmallVector<Value> createDestinationPassingStyleInitOperands(
+    LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  OpOperand *firstInit = op.getDpsInitOperand(0);
+  const unsigned initIdx = firstInit->getOperandNumber();
+  Value spmdizedInit = spmdizationMap.lookup(firstInit->get());
+
+  SmallVector<Value> result(spmdizedOperands.begin(), spmdizedOperands.end());
+  result[initIdx] = createDestinationPassingStyleInitOperand(
+      op, spmdizedInit, reductionMeshAxes, meshOp, builder);
+  return result;
+}
+
+// Emits a mesh.all_reduce of the spmdized value of `unshardedLinalgOpResult`
+// over those axes of `opReductionMeshAxes` that are not listed in
+// `resultSharding.getPartialAxes()`. Then remaps `unshardedLinalgOpResult`
+// in `spmdizationMap` to the reduced value. Emits nothing and leaves the
+// mapping unchanged when no axis remains.
+static void createAllReduceForResultWithoutPartialSharding(
+    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
+    MeshShardingAttr resultSharding, ReductionKind reductionKind,
+    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
+  SmallVector<MeshAxis> axesToReduce;
+  for (MeshAxis axis : opReductionMeshAxes) {
+    // Axes that the result sharding marks as partial are not reduced here.
+    if (llvm::is_contained(resultSharding.getPartialAxes(), axis))
+      continue;
+    axesToReduce.push_back(axis);
+  }
+  if (axesToReduce.empty())
+    return;
+
+  Value spmdizedResult = spmdizationMap.lookup(unshardedLinalgOpResult);
+  Value allReduced = builder.create<mesh::AllReduceOp>(
+      spmdizedResult, resultSharding.getMesh().getValue(), axesToReduce,
+      reductionKind);
+  spmdizationMap.map(unshardedLinalgOpResult, allReduced);
+}
+
+// Runs createAllReduceForResultWithoutPartialSharding for every result of
+// `unshardedOp`, paired with its entry in `resultShardings`. All results use
+// the reduction kind derived from the op's reduction combiner.
+static void createAllReduceForResultsWithoutPartialShardings(
+    LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  const ReductionKind kind = getReductionKindOfLinalgOp(unshardedOp);
+  for (auto &&[result, sharding] :
+       llvm::zip(unshardedOp->getResults(), resultShardings))
+    createAllReduceForResultWithoutPartialSharding(
+        result, opReductionMeshAxes, sharding, kind, spmdizationMap, builder);
+}
+
+// Spmdizes `op` when some of its reduction loop iterators are assigned to mesh
+// axes. The reduction mesh axes are computed from `loopIteratorTypes` and
+// `meshAxisAssignmentForLoopIterators`. Steps:
+//   1. rewrite the DPS init operand so the initial value is used by only one
+//      process per reduction group;
+//   2. spmdize the op as a trivially shardable operation on those operands;
+//   3. all-reduce each result over the reduction mesh axes that are not
+//      partial axes of that result's sharding.
+// Only the result mappings in `spmdizationMap` are updated.
+static void spmdizeLinalgOpWithShardedReduction(
+    LinalgOp op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
+    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+    ImplicitLocOpBuilder &builder) {
+  MeshOp meshOp = getMesh(op, operandShardings, resultShardings, symbolTable);
+  SmallVector<MeshAxis> reductionAxes = mesh::getReductionMeshAxes(
+      loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+  SmallVector<Value> newOperands = createDestinationPassingStyleInitOperands(
+      op, meshOp, spmdizedOperands, reductionAxes, spmdizationMap, builder);
+
+  // Spmdize against a private operand mapping. The operand mappings in
+  // `spmdizationMap` cover the whole spmdization blob and may be used by
+  // others, so they must not be overwritten with the rewritten init operand.
+  IRMapping localMap;
+  localMap.map(op->getOperands(), newOperands);
+  spmdizeTriviallyShardableOperation(*op, newOperands, operandShardings,
+                                     resultShardings, localMap, symbolTable,
+                                     builder);
+
+  // Copy only the result mappings back into the shared map.
+  for (Value result : op->getResults())
+    spmdizationMap.map(result, localMap.lookup(result));
+
+  // All-reduce the results over the reduction mesh axes, skipping axes that
+  // a result's sharding lists as partial.
+  createAllReduceForResultsWithoutPartialShardings(
+      op, reductionAxes, resultShardings, spmdizationMap, builder);
+}
+
+namespace {
+
+// ShardingInterface for ops that implement LinalgStructuredInterface.
+// The supported ops are only those where the indexing maps are projected
+// permutations.
+template <typename Op>
+struct StructuredOpShardingInterface
+ : public mesh::ShardingInterface::ExternalModel<
+ StructuredOpShardingInterface<Op>, Op> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+ SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
+
+ // Results must have the same indexing as destination passing style initial
+ // operands.
+ for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
+ res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
+ }
+
+ return res;
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+
+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+ bool allIndexingMapsAreProjectedPermutation =
+ llvm::all_of(indexingMaps, [](AffineMap map) {
+ return map.isProjectedPermutation();
+ });
+ if (!allIndexingMapsAreProjectedPermutation) {
+ // TODO: handle non-projected permutations.
+ op->emitOpError()
+ << "Only projected permutation indexing maps are supported.";
----------------
antiagainst wrote:
Nit: error messages typically start with a lowercase letter so that they compose well with the prefix. Here it would read "'linalg.*' op only projected ...". (You can also adjust the error message so it reads more naturally.)
https://github.com/llvm/llvm-project/pull/82284
More information about the Mlir-commits
mailing list