[Mlir-commits] [mlir] 466abaf - Revert "[MLIR][Mesh] Add sharding propagation pass (#69665)"
Mehdi Amini
llvmlistbot at llvm.org
Fri Nov 3 17:53:17 PDT 2023
Author: Mehdi Amini
Date: 2023-11-03T17:52:41-07:00
New Revision: 466abaf152921875fa91f2c75d675c4ce0e2acc5
URL: https://github.com/llvm/llvm-project/commit/466abaf152921875fa91f2c75d675c4ce0e2acc5
DIFF: https://github.com/llvm/llvm-project/commit/466abaf152921875fa91f2c75d675c4ce0e2acc5.diff
LOG: Revert "[MLIR][Mesh] Add sharding propagation pass (#69665)"
This reverts commit 9d9400d7de9b928e3018af97e8b381a4a6ba5162.
This reverts commit bda763aea0b854178c01eac9f309042d9aaa823b.
The buildbot is broken and tests are failing.
Added:
Modified:
mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
mlir/include/mlir/IR/AffineMap.h
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/Mesh/CMakeLists.txt
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/lib/Dialect/Tosa/CMakeLists.txt
mlir/lib/IR/AffineMap.cpp
Removed:
mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
mlir/test/Dialect/Mesh/sharding-propagation.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index fa8842fb04fd721..f33061b2d87cffc 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1,3 +1 @@
-add_subdirectory(Interfaces)
add_subdirectory(IR)
-add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..39d24595ec1c446 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,22 +49,6 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
-// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
-// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
-// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
-// is partial.
-def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
- I32EnumAttrCase<"Parallel", 1, "parallel">,
- I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
- I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
- I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
- I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
- I32EnumAttrCase<"Invalid", 100, "invalid">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::mesh";
-}
-
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
@@ -138,24 +122,6 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
$partial_axes^ `]`)? `>`
}];
- let builders = [
- AttrBuilder<(ins "SymbolRefAttr":$cluster,
- "ArrayRef<SmallVector<int32_t>>":$split_axes,
- "ArrayRef<int32_t>": $partial_axes,
- "mesh::Partial": $partial_type), [{
- SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
- split_axes, [&](ArrayRef<int32_t> array) {
- return DenseI32ArrayAttr::get($_ctxt, array);
- });
- return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
- partial_type);
- }]>,
- AttrBuilder<(ins "SymbolRefAttr":$cluster,
- "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
- return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
- }]>
- ];
-
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..9dfeca84d012165 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,22 +24,4 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
-namespace mlir {
-namespace mesh {
-
-bool isReductionLoop(IteratorType iType);
-
-bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
-
-template <typename T>
-void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
- while (!array.empty() && array.back().empty())
- array.pop_back();
-}
-
-Partial getPartialTypeFromReduction(IteratorType iType);
-
-} // namespace mesh
-} // namespace mlir
-
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
deleted file mode 100644
index b3a44f3b0089abc..000000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
-mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
-mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
deleted file mode 100644
index d860628cf371aa9..000000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ /dev/null
@@ -1,68 +0,0 @@
-//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
-
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-
-class Operation;
-
-namespace mesh {
-
-using ShardingArray = SmallVector<SmallVector<int32_t>>;
-using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
-
-struct ShardingOption {
- // An array of int array. The sub-array at the i-th position signifies the
- // mesh axes the i-th loop will be sharded on.
- ShardingArray shardingArray;
- SymbolRefAttr cluster;
- // `empty` being true indicates that no sharding information can be inferred
- // at present. Note that it is different from the case where an operation is
- // not sharded.
- bool empty = false;
- ShardingOption() = default;
- ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
- : shardingArray(std::move(shardingArray)), cluster(cluster) {}
-};
-
-// This method retrieves the 'MeshShardingAttr' attribute from a given operation
-// result and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshShardingAttr>>
-getMeshShardingAttr(OpResult result);
-
-// This method retrieves the 'MeshShardingAttr' attribute from a given operation
-// operand and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshShardingAttr>>
-getMeshShardingAttr(OpOperand &opOperand);
-
-namespace detail {
-
-FailureOr<ShardingOption>
-defaultGetShardingOption(Operation *op,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings);
-
-LogicalResult
-defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
- const ShardingOption &shardingOption);
-
-} // namespace detail
-
-} // namespace mesh
-
-} // namespace mlir
-
-/// Include the ODS generated interface header files.
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
-
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
deleted file mode 100644
index 21b6c8d4f599a8d..000000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ /dev/null
@@ -1,102 +0,0 @@
-//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
-
-include "mlir/IR/OpBase.td"
-
-def ShardingInterface : OpInterface<"ShardingInterface"> {
- let description = [{
- Interface for allowing operations to expose information needed to
- shard them.
- }];
- let cppNamespace = "::mlir::mesh";
-
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Returns a list of iterator types that describe the number of loops.
- The iterator types determine how the operation traverses its input and
- output tensors.
-
- Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
- types are parallel, parallel, reduction-sum. This indicates that M and
- N are traversed in parallel, while the K dimension is used for
- reduction.
-
- Example 2: A softmax op's loop iterator types are parallel and
- invalid. The second dimension is considered as invalid because it is
- neither parallel nor any kind of reduction.
- }],
- /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
- /*methodName=*/"getLoopIteratorTypes",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return {};"
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the indexing maps attribute within the current operation.
- Indexing maps determine how indices in the iteration space map to
- tensor indices. They are specified using `affine_map` in MLIR, which
- provides an affine transformation of indices.
- }],
- /*retTy=*/"SmallVector<AffineMap>",
- /*methodName=*/"getIndexingMaps",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return {};"
- >,
- InterfaceMethod<
- /*desc=*/[{
- Given that certain operands or results of the operation may have
- sharding annotations, this method leverages this information to deduce
- how the operation should be sharded.
- }],
- /*retTy=*/"FailureOr<ShardingOption>",
- /*methodName=*/"getShardingOption",
- /*args=*/(ins
- "ArrayRef<MeshShardingAttr>": $operandShardings,
- "ArrayRef<MeshShardingAttr>": $resultShardings
- ),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return detail::defaultGetShardingOption(
- $_op.getOperation(), operandShardings, resultShardings);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Based on a given ShardingOption, this method adds `mesh.shard`
- operations for the operands and results that previously lacked
- sharding annotations.
- }],
- /*retTy=*/"LogicalResult",
- /*methodName=*/"addShardingAnnotations",
- /*args=*/(ins
- "OpBuilder &":$b,
- "const ShardingOption &":$shardingOption
- ),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return detail::defaultAddShardingAnnotations(
- $_op.getOperation(), b, shardingOption);
- }]
- >
- ];
-
- let extraClassDeclaration = [{
- LogicalResult verifyShardingInterfaceImpl();
-
- void printLoopTypesAndIndexingMaps(raw_ostream &os);
- }];
-}
-
-
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
deleted file mode 100644
index 8d768485103b65f..000000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
-add_public_tablegen_target(MLIRMeshPassIncGen)
-add_dependencies(mlir-headers MLIRMeshPassIncGen)
-
-add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
deleted file mode 100644
index 83399d10beaae48..000000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ /dev/null
@@ -1,39 +0,0 @@
-//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
-
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-
-namespace func {
-class FuncOp;
-}
-
-namespace mesh {
-
-//===----------------------------------------------------------------------===//
-// Passes
-//===----------------------------------------------------------------------===//
-
-#define GEN_PASS_DECL
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
-
-//===----------------------------------------------------------------------===//
-// Registration
-//===----------------------------------------------------------------------===//
-
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
-
-} // namespace mesh
-} // namespace mlir
-
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
deleted file mode 100644
index c09cf3e710d4278..000000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ /dev/null
@@ -1,32 +0,0 @@
-//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
-#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
-
-include "mlir/Pass/PassBase.td"
-
-//===----------------------------------------------------------------------===//
-// ShardingPropagation
-//===----------------------------------------------------------------------===//
-
-def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
- let summary = "sharding propagation";
- let description = [{
- Propagates sharding information throughout the graph. After this pass, each
- of the operations' operands and results is annotated with a `mesh.shard`
- operation, and the operations themselves are added with sharding option
- attributes.
- }];
- let dependentDialects = [
- "mesh::MeshDialect"
- ];
-}
-
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
deleted file mode 100644
index 16427919dace5da..000000000000000
--- a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
+++ /dev/null
@@ -1,23 +0,0 @@
-//===- ShardingInterfaceImpl.h - ------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
-#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
-
-namespace mlir {
-
-class DialectRegistry;
-
-namespace tosa {
-
-void registerShardingInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace tosa
-} // namespace mlir
-
-#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index f691a3daf8889c5..5af7835258f6bd2 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -104,18 +104,6 @@ class AffineMap {
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
- /// Returns an affine map with `numDims` input dimensions and results
- /// specified by `targets`.
- ///
- /// Examples:
- /// * getMultiDimMapWithTargets(3, [0, 2, 1])
- /// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
- /// * getMultiDimMapWithTargets(3, [2, 1])
- /// -> affine_map<(d0, d1, d2) -> (d2, d1)>
- static AffineMap getMultiDimMapWithTargets(unsigned numDims,
- ArrayRef<unsigned> targets,
- MLIRContext *context);
-
/// Returns a vector of AffineMaps; each with as many results as
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
/// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 395d899f9ad84b0..621110d130818d3 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,7 +79,6 @@
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
-#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -172,7 +171,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
- tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerSubsetOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index f22980036ffcfa1..80894094484b999 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -30,7 +30,6 @@
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -75,7 +74,6 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
- mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index fa8842fb04fd721..f33061b2d87cffc 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,3 +1 @@
-add_subdirectory(Interfaces)
add_subdirectory(IR)
-add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..fc91fd994f12dc2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,37 +41,6 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
return arith::ConstantOp::materialize(builder, value, type, loc);
}
-//===----------------------------------------------------------------------===//
-// Mesh utilities
-//===----------------------------------------------------------------------===//
-
-bool mesh::isReductionLoop(IteratorType iType) {
- return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
-}
-
-bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
- return (partial == Partial::Generic &&
- iType == IteratorType::ReductionGeneric) ||
- (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
- (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
- (partial == Partial::Min && iType == IteratorType::ReductionMin);
-}
-
-Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
- switch (iType) {
- case IteratorType::ReductionGeneric:
- return Partial::Generic;
- case IteratorType::ReductionSum:
- return Partial::Sum;
- case IteratorType::ReductionMax:
- return Partial::Max;
- case IteratorType::ReductionMin:
- return Partial::Min;
- default:
- llvm_unreachable("No corresponding partial type can be found");
- }
-}
-
//===----------------------------------------------------------------------===//
// mesh.cluster op
//===----------------------------------------------------------------------===//
@@ -126,6 +95,7 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}
if (failed(checkMeshAxis(partialAxes)))
return failure();
+
return success();
}
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
deleted file mode 100644
index 1010756f1fe279a..000000000000000
--- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-add_mlir_library(MLIRShardingInterface
- ShardingInterface.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
-
- DEPENDS
- MLIRShardingInterfaceIncGen
-
- LINK_LIBS PUBLIC
- MLIRDialectUtils
- MLIRIR
- MLIRMeshDialect
- MLIRSupport
-)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
deleted file mode 100644
index c2e1d1c726816a5..000000000000000
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ /dev/null
@@ -1,524 +0,0 @@
-//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/SmallSet.h"
-#include "llvm/Support/Debug.h"
-
-#include <algorithm>
-#include <utility>
-
-#define DEBUG_TYPE "sharding-interface"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
-
-using namespace mlir;
-using namespace mlir::mesh;
-
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
-
-//===----------------------------------------------------------------------===//
-// common util functions
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-checkOperandAffineExprRecursively(AffineExpr expr,
- SmallVectorImpl<bool> &seenIds) {
- switch (expr.getKind()) {
- case AffineExprKind::Add: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- AffineExpr lhs = binOpExpr.getLHS();
- AffineExpr rhs = binOpExpr.getRHS();
- if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
- return failure();
- if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
- return failure();
- return success();
- }
- case AffineExprKind::Mul: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- AffineExpr lhs = binOpExpr.getLHS();
- AffineExpr rhs = binOpExpr.getRHS();
- AffineExpr dimExpr;
- if (lhs.getKind() == AffineExprKind::DimId &&
- rhs.getKind() == AffineExprKind::Constant) {
- dimExpr = lhs;
- } else if (rhs.getKind() == AffineExprKind::DimId &&
- lhs.getKind() == AffineExprKind::Constant) {
- dimExpr = rhs;
- } else
- return failure();
- unsigned position = dimExpr.cast<AffineDimExpr>().getPosition();
- if ((size_t)position >= seenIds.size() || seenIds[position])
- return failure();
- seenIds[position] = true;
- return success();
- }
- case AffineExprKind::DimId: {
- unsigned position = expr.cast<AffineDimExpr>().getPosition();
- if ((size_t)position >= seenIds.size() || seenIds[position])
- return failure();
- seenIds[position] = true;
- return success();
- }
- default:
- return failure();
- }
-}
-
-static FailureOr<llvm::SmallSet<unsigned, 2>>
-checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
- SmallVector<bool> seenIds(numDims, false);
- if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
- return failure();
-
- llvm::SmallSet<unsigned, 2> positions;
- for (auto it : llvm::enumerate(seenIds)) {
- if (it.value())
- positions.insert((unsigned)it.index());
- }
- return positions;
-}
-
-//===----------------------------------------------------------------------===//
-// mesh::getMeshShardingAttr
-//===----------------------------------------------------------------------===//
-
-FailureOr<std::pair<bool, MeshShardingAttr>>
-mesh::getMeshShardingAttr(OpResult result) {
- Value val = result.cast<Value>();
- bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
- if (!shardOp)
- return false;
- return !shardOp.getAnnotateForUsers();
- });
-
- if (anyShardedForDef) {
- // expected to have exact one use if it has a use of `mesh.shard` without
- // unit attr annotate_for_users
- if (!val.hasOneUse())
- return failure();
- auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
- return std::make_pair(false, shardOp.getShard());
- }
-
- bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
- if (!shardOp)
- return false;
- return shardOp.getAnnotateForUsers();
- });
- if (anyShardedForUsers) {
- SmallVector<ShardOp> shardOps;
- for (Operation *user : val.getUsers()) {
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
- if (shardOp)
- shardOps.push_back(shardOp);
- }
- MeshShardingAttr shardForDef = shardOps[0].getShard();
- for (size_t i = 1; i < shardOps.size(); ++i) {
- // TODO: Deduce a reasonable mesh sharding attr for def when they are
- // different
- assert(shardOps[i].getShard() == shardForDef &&
- "only support all shard ops have the same mesh sharding attr");
- }
- return std::make_pair(true, shardForDef);
- }
- return failure();
-}
-
-FailureOr<std::pair<bool, MeshShardingAttr>>
-mesh::getMeshShardingAttr(OpOperand &opOperand) {
- Value val = opOperand.get();
- if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
- return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
-
- return failure();
-}
-
-//===----------------------------------------------------------------------===//
-// ShardingInterface::verifyShardingInterfaceImpl
-//===----------------------------------------------------------------------===//
-
-LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
- Operation *op = getOperation();
-
- // check operands and results type
- for (Type type : op->getOperandTypes())
- if (!llvm::isa<RankedTensorType>(type))
- return failure();
- for (Type type : op->getResultTypes())
- if (!llvm::isa<RankedTensorType>(type))
- return failure();
-
- // check loop types
- SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
- if (loopTypes.size() == 0)
- return failure();
-
- // check maps
- SmallVector<AffineMap> maps = getIndexingMaps();
- if (maps.size() == 0)
- return failure();
- unsigned numOperands = op->getNumOperands();
- unsigned numResults = op->getNumResults();
- if (numOperands + numResults != maps.size())
- return failure();
-
- for (OpResult result : op->getResults()) {
- auto resultType = result.getType().dyn_cast<RankedTensorType>();
- if (!resultType)
- return failure();
- AffineMap map = maps[numOperands + result.getResultNumber()];
- if (!map.isProjectedPermutation()) {
- return failure();
- }
- }
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ShardingInterface::printLoopTypesAndIndexingMaps
-//===----------------------------------------------------------------------===//
-
-void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
- os << "print loop types and indexing maps for: \n";
- getOperation()->print(os);
- os << "\n";
- os << "loop types: [";
- for (IteratorType type : getLoopIteratorTypes()) {
- os << stringifyEnum(type) << " ";
- }
- os << "]\n";
- os << "indexing maps: \n";
- for (AffineMap map : getIndexingMaps())
- os << map << "\n";
- os << "\n";
-}
-
-//===----------------------------------------------------------------------===//
-// detail::defaultGetShardingOption
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
-static LogicalResult fillShardingOption(Operation *op,
- ShardingOption &shardingOption,
- SymbolRefAttr cluster,
- ArrayRef<int32_t> meshAxes,
- unsigned loopIdx) {
- if ((shardingOption.cluster && cluster &&
- shardingOption.cluster != cluster) ||
- (!shardingOption.shardingArray[loopIdx].empty() &&
- shardingOption.shardingArray[loopIdx] != meshAxes)) {
- LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
- << loopIdx << "\n");
- return failure();
- }
- for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
- if (i == loopIdx)
- continue;
-
- for (int32_t axis : meshAxes) {
- if (std::find(shardingOption.shardingArray[i].begin(),
- shardingOption.shardingArray[i].end(),
- axis) != shardingOption.shardingArray[i].end()) {
- LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
- << axis << " duplicate");
- return failure();
- }
- }
- }
- if (cluster)
- shardingOption.cluster = cluster;
- if (shardingOption.shardingArray[loopIdx].empty())
- shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
- meshAxes.end());
- return success();
-}
-
-} // namespace
-
-FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
- Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings) {
- ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
- ShardingOption shardingOption;
-
- if (failed(shardingOp.verifyShardingInterfaceImpl()))
- return op->emitOpError() << "invalid sharding interface implementation";
- SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
- SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
- unsigned numOperands = op->getNumOperands();
- shardingOption.shardingArray.resize(loopTypes.size());
- llvm::SmallVector<int32_t> partialMeshAxes;
- Partial partialType;
- llvm::SmallSet<unsigned, 4> visitedLoopIndices;
- bool anyShardingInResultsOrOperands = false;
-
- // 1. Fill sharding option based on op results
- for (auto shardingIt : llvm::enumerate(resultShardings)) {
- MeshShardingAttr shardAttr = shardingIt.value();
- if (!shardAttr)
- continue;
- AffineMap map = maps[numOperands + shardingIt.index()];
- anyShardingInResultsOrOperands = true;
- // Handle the split axes: calculate the corresponding loop index for each
- // split axes sub-array, and then store the sub-array to
- // shardingOption[index]
- for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
- AffineExpr expr = std::get<0>(it);
- ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
- auto dim = expr.cast<AffineDimExpr>();
- unsigned index = dim.getPosition();
- visitedLoopIndices.insert(index);
- if (failed(fillShardingOption(op, shardingOption, shardAttr.getCluster(),
- axes, index)))
- return failure();
- }
-
- // Handle the partial axes: at this stage, the exact loop index/indices
- // cannot be decided because there could be multiple reduction loops.
- ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
- if (!partialAxes.empty()) {
- if (!partialMeshAxes.empty())
- return op->emitOpError() << "at most one result with partial axes is "
- "supported at present";
- partialType = shardAttr.getPartialType();
- partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
- // Add all the reduction loop indices to `visitedLoopIndices` if
- // `partialAxes` is not empty
- for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
- if (isReductionLoop(loopTypes[loopIdx]))
- visitedLoopIndices.insert(loopIdx);
- }
- }
- }
-
- // 2. Fill sharding option based on operands
- for (auto shardingIt : llvm::enumerate(operandShardings)) {
- MeshShardingAttr shardAttr = shardingIt.value();
- if (!shardAttr)
- continue;
-
- anyShardingInResultsOrOperands = true;
- AffineMap map = maps[shardingIt.index()];
- unsigned numDims = map.getNumDims();
-
- // Handle the split axes. Partial axes don't need to be handled because they
- // only affect the defining op of the operand.
- //
- // TODO: Change to process the operands with single loop index first and
- // then the operands with multiple loop indices.
- for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
- AffineExpr expr = std::get<0>(it);
- ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
- FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
- checkOperandAffineExpr(expr, numDims);
- if (failed(loopIndices))
- return op->emitOpError()
- << "operand's affine expression is restricted to const_i * "
- "dim_i + const_j + dim_j + ...";
- if (loopIndices->empty())
- continue;
- if (loopIndices->size() == 1) {
- unsigned loopIdx = *loopIndices->begin();
- visitedLoopIndices.insert(loopIdx);
- if (failed(fillShardingOption(op, shardingOption,
- shardAttr.getCluster(), axes, loopIdx)))
- return failure();
- }
- // If multiple loop indices correspond to a dimension of an operand, it is
- // difficult to infer which loop indices are responsible for sharding.
- // Therefore, the exact loop index must be specified by others.
- if (loopIndices->size() > 1) {
- bool seenLoopIndices = false;
- for (unsigned loopIdx : *loopIndices) {
- if (visitedLoopIndices.contains(loopIdx)) {
- seenLoopIndices = true;
- break;
- }
- }
- if (!seenLoopIndices)
- return op->emitOpError()
- << "the operand " << shardingIt.index()
- << " has multiple loop indices in a dimension, but none of "
- "them could be found in the exactly specified annotation "
- "of op results or operands.";
- }
- }
- }
-
- // 3. Finalize sharding option
- if (!partialMeshAxes.empty()) {
- bool anyNonEmptyReductionLoop = llvm::any_of(
- llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
- SmallVector<int32_t> &subArray = it.value();
- int64_t idx = it.index();
- return isReductionLoop(loopTypes[idx]) && !subArray.empty();
- });
- if (!anyNonEmptyReductionLoop) {
- bool filled = false;
- for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
- if (isReductionLoop(loopTypes[idx]) &&
- areReductionAndPartialMatch(loopTypes[idx], partialType)) {
- std::ignore = fillShardingOption(op, shardingOption, nullptr,
- partialMeshAxes, idx);
- filled = true;
- break;
- }
- }
- if (!filled)
- return op->emitOpError() << "no matched reduction loop found for the "
- "result's partial type";
- }
- }
- removeTrailingEmptySubArray(shardingOption.shardingArray);
- if (!anyShardingInResultsOrOperands)
- shardingOption.empty = true;
- return shardingOption;
-}
-
-//===----------------------------------------------------------------------===//
-// detail::defaultAddShardingAnnotations
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// To add a `mesh.shard` op for the given result, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
-static LogicalResult addShardOp(OpBuilder &b, OpResult result,
- const ShardingOption &shardingOption,
- AffineMap map,
- ArrayRef<IteratorType> loopTypes) {
- FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
- getMeshShardingAttr(result);
- if (succeeded(maybeSharding) && !maybeSharding->first)
- return success();
-
- auto resultType = result.getType().cast<RankedTensorType>();
- SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
- SmallVector<int32_t> partialAxes;
-
- // process the split axes
- for (auto it : llvm::enumerate(map.getResults())) {
- AffineExpr expr = it.value();
- // `expr` must be an `AffineDimExpr` because `map` is verified by
- // isProjectedPermutation
- auto dim = expr.cast<AffineDimExpr>();
- unsigned loopIdx = dim.getPosition();
- if (loopIdx < shardingOption.shardingArray.size())
- splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
- }
-
- // process the partial axes
- Partial partialType;
- for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
- IteratorType iType = std::get<0>(it);
- if (isReductionLoop(iType)) {
- Partial curPartialType = getPartialTypeFromReduction(iType);
- if (!partialAxes.empty())
- assert(partialType == curPartialType &&
- "Only one reduction type is supported");
- partialType = curPartialType;
- const SmallVector<int32_t> &axis = std::get<1>(it);
- partialAxes.append(axis);
- }
- }
-
- removeTrailingEmptySubArray(splitAxes);
- MeshShardingAttr shardAttr =
- MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes,
- partialAxes, partialType);
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointAfterValue(result);
- auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
- shardAttr, /*annotate_for_users*/ false);
- result.replaceAllUsesExcept(shardOp, shardOp);
- return success();
-}
-
-// To add a `mesh.shard` op for the given operand, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
-static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
- const ShardingOption &shardingOption,
- AffineMap map,
- ArrayRef<IteratorType> loopTypes) {
- auto maybeShardingAttr = getMeshShardingAttr(opOperand);
- if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
- return success();
- Value operand = opOperand.get();
- auto operandType = operand.getType().cast<RankedTensorType>();
- SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
- unsigned numDims = map.getNumDims();
- for (auto it : llvm::enumerate(map.getResults())) {
- int64_t idx = it.index();
- AffineExpr expr = it.value();
- FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
- checkOperandAffineExpr(expr, numDims);
- if (failed(loopIndices))
- return failure();
- SmallVector<unsigned> shardedLoopIndices;
- for (unsigned loopIdx : *loopIndices) {
- if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
- !shardingOption.shardingArray[loopIdx].empty())
- shardedLoopIndices.push_back(loopIdx);
- }
- // mostly one sharded loop index is accepted
- if (shardedLoopIndices.size() > 1)
- return failure();
- if (shardedLoopIndices.size() == 1) {
- splitAxes[idx].append(
- shardingOption.shardingArray[shardedLoopIndices[0]]);
- }
- }
-
- removeTrailingEmptySubArray(splitAxes);
- MeshShardingAttr shardAttr =
- MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes);
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(opOperand.getOwner());
- auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
- shardAttr, true);
- opOperand.set(shardOp);
-
- return success();
-}
-
-} // namespace
-
-LogicalResult mesh::detail::defaultAddShardingAnnotations(
- Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
- ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
- SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
- SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
- unsigned numOperands = op->getNumOperands();
-
- // 1. add mesh.shard ops for all op results
- for (OpResult result : op->getResults()) {
- if (failed(addShardOp(b, result, shardingOption,
- maps[numOperands + result.getResultNumber()],
- loopTypes)))
- return failure();
- }
-
- // 2. add mesh.shard ops for all operands
- for (OpOperand &opOperand : op->getOpOperands()) {
- if (failed(addShardOp(b, opOperand, shardingOption,
- maps[opOperand.getOperandNumber()], loopTypes)))
- return failure();
- }
-
- return success();
-}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
deleted file mode 100644
index 9f85d8e9cb22d5b..000000000000000
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-add_mlir_dialect_library(MLIRMeshTransforms
- ShardingPropagation.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
-
- DEPENDS
- MLIRMeshPassIncGen
- MLIRShardingInterface
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRMeshDialect
- MLIRShardingInterface
- MLIRSupport
- MLIRTosaShardingInterfaceImpl
-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
deleted file mode 100644
index 3aed912fb43c63e..000000000000000
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ /dev/null
@@ -1,210 +0,0 @@
-//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Mesh/Transforms/Passes.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Pass/Pass.h"
-#include "llvm/Support/Debug.h"
-#include <vector>
-
-namespace mlir {
-namespace mesh {
-#define GEN_PASS_DEF_SHARDINGPROPAGATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
-} // namespace mesh
-} // namespace mlir
-
-#define DEBUG_TYPE "sharding-propagation"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
-
-using namespace mlir;
-using namespace mlir::mesh;
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Utilities
-//===----------------------------------------------------------------------===//
-
-// This method retrieves all potential sharding attributes, prioritizing
-// specific shardings. For example, mustShardings = [shard0, None] and
-// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
-// [shard0, None]]
-static SmallVector<SmallVector<MeshShardingAttr>>
-getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
- ArrayRef<MeshShardingAttr> optionalShardings) {
- SmallVector<SmallVector<MeshShardingAttr>> allShardingAttrs;
- SmallVector<MeshShardingAttr> curShardingAttrs;
-
- std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
- if (i == mustShardings.size()) {
- allShardingAttrs.push_back(
- SmallVector<MeshShardingAttr>(curShardingAttrs));
- return;
- }
-
- if (mustShardings[i]) {
- curShardingAttrs.push_back(mustShardings[i]);
- dfsCreateShardingAttrs(i + 1);
- curShardingAttrs.pop_back();
- return;
- }
-
- if (optionalShardings[i]) {
- curShardingAttrs.push_back(optionalShardings[i]);
- dfsCreateShardingAttrs(i + 1);
- curShardingAttrs.pop_back();
- curShardingAttrs.push_back(nullptr);
- dfsCreateShardingAttrs(i + 1);
- curShardingAttrs.pop_back();
- return;
- }
-
- curShardingAttrs.push_back(nullptr);
- dfsCreateShardingAttrs(i + 1);
- curShardingAttrs.pop_back();
- };
-
- dfsCreateShardingAttrs(0);
- return allShardingAttrs;
-}
-
-// For each operation that implements the ShardingInterface, infer the sharding
-// option of the operation from its operands and/or results using the
-// `getShardingOption` method. If the inferred sharding option is not empty, add
-// a `mesh.shard` operation for all remaining operands and results that do not
-// have sharding annotations.
-LogicalResult visitOp(Operation *op, OpBuilder &builder) {
- if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
- return success();
-
- ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
- if (!shardingOp) {
- op->emitOpError() << "sharding interface is not implemented.";
- return failure();
- }
-
- // collect MeshShardingAttr from results
- SmallVector<MeshShardingAttr> allowConflictsResultShardings;
- allowConflictsResultShardings.resize(op->getNumResults());
- SmallVector<MeshShardingAttr> resultMustShardings;
- resultMustShardings.resize(op->getNumResults());
- for (OpResult result : op->getResults()) {
- FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
- getMeshShardingAttr(result);
- if (failed(maybeShardAttr))
- continue;
- if (!maybeShardAttr->first)
- resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
- else
- allowConflictsResultShardings[result.getResultNumber()] =
- maybeShardAttr->second;
- }
-
- // collect MeshShardingAttr from operands
- SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
- allowConflictsOperandShardings.resize(op->getNumOperands());
- SmallVector<MeshShardingAttr> operandMustShardings;
- operandMustShardings.resize(op->getNumOperands());
- for (OpOperand &opOperand : op->getOpOperands()) {
- FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
- getMeshShardingAttr(opOperand);
- if (failed(maybeShardAttr))
- continue;
-
- if (maybeShardAttr->first)
- operandMustShardings[opOperand.getOperandNumber()] =
- maybeShardAttr->second;
- else
- allowConflictsOperandShardings[opOperand.getOperandNumber()] =
- maybeShardAttr->second;
- }
-
- // try to get the sharding option
- SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
- getOrderedPossibleShardingAttrs(operandMustShardings,
- allowConflictsOperandShardings);
- SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
- getOrderedPossibleShardingAttrs(resultMustShardings,
- allowConflictsResultShardings);
- FailureOr<ShardingOption> finalShardingOption = failure();
- for (ArrayRef<MeshShardingAttr> resultShardings :
- possibleResultShardingAttrs) {
- if (succeeded(finalShardingOption))
- break;
- for (ArrayRef<MeshShardingAttr> operandShardings :
- possibleOperandShardingAttrs) {
- FailureOr<ShardingOption> shardingOption =
- shardingOp.getShardingOption(operandShardings, resultShardings);
- if (succeeded(shardingOption)) {
- finalShardingOption = shardingOption;
- break;
- }
- }
- }
-
- if (failed(finalShardingOption)) {
- op->emitOpError() << "fail to get sharding option.";
- return failure();
- }
- // sharding info is empty, return immediately
- if (finalShardingOption->empty)
- return success();
-
- if (failed(
- shardingOp.addShardingAnnotations(builder, *finalShardingOption))) {
- op->emitOpError() << "fail to set sharding annotations.";
- return failure();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ShardingPropagation
-//===----------------------------------------------------------------------===//
-struct ShardingPropagation
- : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
- void runOnOperation() override {
- func::FuncOp funcOp = getOperation();
- MLIRContext *ctx = funcOp.getContext();
- Region &region = funcOp.getBody();
- OpBuilder builder(ctx);
- if (!region.hasOneBlock()) {
- funcOp.emitOpError() << "only one block is supported!";
- signalPassFailure();
- }
- Block &block = region.front();
-
- LLVM_DEBUG(
- DBGS() << "print all the ops' iterator types and indexing maps in the "
- "block.\n";
- for (Operation &op
- : block.getOperations()) {
- if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
- shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
- });
-
- // 1. propagate in reversed order
- for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
-
- LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
- << funcOp << "\n");
-
- // 2. propagate in original order
- for (Operation &op : llvm::make_early_inc_range(block))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
- }
-};
-
-} // namespace
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index ba5343dcd7ac6c1..8e32579e0e4c2e3 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -26,18 +26,4 @@ add_mlir_dialect_library(MLIRTosaDialect
MLIRViewLikeInterface
)
-add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
- IR/ShardingInterfaceImpl.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRMeshDialect
- MLIRShardingInterface
- MLIRSupport
- MLIRTosaDialect
- )
-
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
deleted file mode 100644
index dace86533c0e231..000000000000000
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ /dev/null
@@ -1,111 +0,0 @@
-//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/DialectRegistry.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "tosa-sharding-impl"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
-
-using namespace mlir;
-using namespace mlir::tosa;
-using namespace mlir::mesh;
-
-namespace {
-
-template <typename ElemwiseOp>
-struct ElemwiseSharding
- : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
- ElemwiseOp> {
- SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
- Value val = op->getOperand(0);
- auto type = val.getType().dyn_cast<RankedTensorType>();
- if (!type)
- return {};
- SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
- return types;
- }
-
- SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
- MLIRContext *ctx = op->getContext();
- Value val = op->getOperand(0);
- auto type = val.getType().dyn_cast<RankedTensorType>();
- if (!type)
- return {};
- int64_t rank = type.getRank();
- int64_t num = op->getNumOperands() + op->getNumResults();
- SmallVector<AffineMap> maps(num,
- AffineMap::getMultiDimIdentityMap(rank, ctx));
- return maps;
- }
-};
-
-// loop types: [parallel, parallel, parallel, reduction_sum]
-// indexing maps:
-// (d0, d1, d2, d3) -> (d0, d1, d3)
-// (d0, d1, d2, d3) -> (d0, d3, d2)
-// (d0, d1, d2, d3) -> (d0, d1, d2)
-struct MatMulOpSharding
- : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
- SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
- auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
- if (!tensorType)
- return {};
-
- SmallVector<IteratorType> types(tensorType.getRank() + 1,
- IteratorType::Parallel);
- types[tensorType.getRank()] = IteratorType::ReductionSum;
- return types;
- }
-
- SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
- auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
- if (!tensorType)
- return {};
- MLIRContext *ctx = op->getContext();
- SmallVector<AffineMap> maps;
- maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
- maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
- maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
- return maps;
- }
-};
-
-template <typename OpType>
-static void registerElemwiseOne(MLIRContext *ctx) {
- OpType::template attachInterface<ElemwiseSharding<OpType>>(*ctx);
-}
-
-/// Variadic helper function.
-template <typename... OpTypes>
-static void registerElemwiseAll(MLIRContext *ctx) {
- (registerElemwiseOne<OpTypes>(ctx), ...);
-}
-
-} // namespace
-
-void mlir::tosa::registerShardingInterfaceExternalModels(
- DialectRegistry &registry) {
-
- registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
- registerElemwiseAll<
- ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
- BitwiseOrOp, BitwiseXorOp, DivOp, LogicalAndOp, LogicalLeftShiftOp,
- LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
- MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
- LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
- GreaterOp, GreaterEqualOp>(ctx);
-
- MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
- });
-}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index cdcd71cdd7cd151..3bd1181b6c7bbd8 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -227,23 +227,15 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context) {
assert(!permutation.empty() &&
"Cannot create permutation map from empty permutation vector");
+ SmallVector<AffineExpr, 4> affExprs;
+ for (auto index : permutation)
+ affExprs.push_back(getAffineDimExpr(index, context));
const auto *m = std::max_element(permutation.begin(), permutation.end());
- auto permutationMap = getMultiDimMapWithTargets(*m + 1, permutation, context);
+ auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
assert(permutationMap.isPermutation() && "Invalid permutation vector");
return permutationMap;
}
-AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
- ArrayRef<unsigned> targets,
- MLIRContext *context) {
- SmallVector<AffineExpr, 4> affExprs;
- for (unsigned t : targets)
- affExprs.push_back(getAffineDimExpr(t, context));
- AffineMap result = AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0,
- affExprs, context);
- return result;
-}
-
template <typename AffineExprContainer>
static SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
deleted file mode 100644
index bda407b52bfd4f2..000000000000000
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ /dev/null
@@ -1,188 +0,0 @@
-// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
-
-mesh.cluster @mesh_1d(rank = 1)
-mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
-mesh.cluster @mesh_3d(rank = 3)
-
-// CHECK-LABEL: func.func @element_wise_empty_sharding_info
-func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: tosa.sigmoid
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: return
- return %0 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_def
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V2]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_use
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V2]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_output
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_input
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @arrow_structure
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
- // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V8:.*]] = tosa.negate %[[V7]]
- // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V6]], %[[V9]]
- return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}0]]> annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
- %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<2x16x32xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
- %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[], [1]], partial = sum[0]> : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
- %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
- %1 = mesh.shard %arg1 to <@mesh_2d, [[], [0]]> annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
- %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %2 : tensor<2x16x32xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(a)
-// CHECK-LABEL: func.func @mlp_1d_weight_stationary
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
-func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
- %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
- // CHECK: %[[V0:.*]] = tosa.matmul
- %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
- %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V6:.*]] = mesh.shard %[[ARG2]] to <@mesh_1d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x32x8xf32>
- // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
- %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
- // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_1d, {{\[\[}}], [], []], partial = sum[0]> : tensor<2x4x8xf32>
- %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
- %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
- // CHECK-DAG: return %[[V9]]
- return %5 : tensor<2x4x8xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(b)
-// CHECK-LABEL: func.func @mlp_2d_weight_stationary
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
-func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
- // CHECK-DAG: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> : tensor<2x4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_3d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[ARG1]] to <@mesh_3d, {{\[\[}}], [0], [1, 2]]> annotate_for_users : tensor<2x8x32xf32>
- // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
- %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_3d, {{\[\[}}], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
- %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V6:.*]] = tosa.sigmoid %[[V5]]
- %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V7:.*]] = mesh.shard %[[V6]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[ARG2]] to <@mesh_3d, {{\[\[}}], [1, 2], [0]]> annotate_for_users : tensor<2x32x8xf32>
- // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
- %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
- // CHECK-DAG: %[[V11:.*]] = mesh.shard %[[V10]] to <@mesh_3d, {{\[\[}}], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
- %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V12:.*]] = mesh.shard %[[V11]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
- %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
- // CHECK-DAG: return %[[V12]]
- return %6 : tensor<2x4x8xf32>
-}