[llvm] [mlir] [mlir] Implement Mesh's ShardingInterface for Linalg ops (PR #82284)
Lei Zhang via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 4 22:41:09 PST 2024
================
@@ -563,6 +563,85 @@ void mesh::spmdizeFullyReplicatedOperation(
builder.clone(op, spmdizationMap);
}
+/// Records which mesh axes shard the loop iterator that `indexingExpr`
+/// refers to, taking them from the mesh axes of one tensor dimension.
+/// If the iterator already has an assignment, the new one is asserted to
+/// match it exactly.
+static void updateMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+    SmallVector<std::optional<SmallVector<MeshAxis>>>
+        &meshAxesAssignmentForLoopIterators) {
+  // The indexing expression must be a plain loop dimension; `cast` asserts
+  // (in builds with assertions enabled) if it is anything else.
+  unsigned iteratorPos = cast<AffineDimExpr>(indexingExpr).getPosition();
+  std::optional<SmallVector<MeshAxis>> &iteratorAxes =
+      meshAxesAssignmentForLoopIterators[iteratorPos];
+
+  // First tensor dimension seen for this iterator: record its mesh axes.
+  if (!iteratorAxes) {
+    iteratorAxes = llvm::to_vector(meshAxesAssignmentForTensorAxis);
+    return;
+  }
+
+  // Otherwise, every tensor dimension mapped to this iterator must agree.
+  assert(llvm::equal(meshAxesAssignmentForTensorAxis, *iteratorAxes));
+}
+
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<AffineMap> indexingMaps) {
+ SmallVector<std::optional<SmallVector<MeshAxis>>>
+ meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+ SmallVector<MeshShardingAttr> operatorAndResultShardings;
+ operatorAndResultShardings.reserve(operandShardings.size() +
+ resultShardings.size());
+ operatorAndResultShardings.insert(operatorAndResultShardings.end(),
+ operandShardings.begin(),
+ operandShardings.end());
+ for (auto [sharding, affineMap] :
+ llvm::zip(operatorAndResultShardings, indexingMaps)) {
----------------
antiagainst wrote:
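Use `llvm::zip_equal` here so the equal-length expectation is explicit? Unlike `llvm::zip`, which silently stops at the shortest range, `zip_equal` asserts that all ranges have the same length. The same applies to the other `zip` uses in this file.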
https://github.com/llvm/llvm-project/pull/82284
More information about the llvm-commits
mailing list