[llvm] [mlir] [mlir] Implement Mesh's ShardingInterface for Linalg ops (PR #82284)

Lei Zhang via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 22:41:09 PST 2024


================
@@ -563,6 +563,85 @@ void mesh::spmdizeFullyReplicatedOperation(
   builder.clone(op, spmdizationMap);
 }
 
+static void updateMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+    SmallVector<std::optional<SmallVector<MeshAxis>>>
+        &meshAxesAssignmentForLoopIterators) {
+  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
+  unsigned loopIteratorIdx = affineDimExpr.getPosition();
+  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
+                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+  } else {
+    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
+        llvm::to_vector(meshAxesAssignmentForTensorAxis);
+  }
+}
+
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<AffineMap> indexingMaps) {
+  SmallVector<std::optional<SmallVector<MeshAxis>>>
+      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+  SmallVector<MeshShardingAttr> operatorAndResultShardings;
+  operatorAndResultShardings.reserve(operandShardings.size() +
+                                     resultShardings.size());
+  operatorAndResultShardings.insert(operatorAndResultShardings.end(),
+                                    operandShardings.begin(),
+                                    operandShardings.end());
+  for (auto [sharding, affineMap] :
+       llvm::zip(operatorAndResultShardings, indexingMaps)) {
+    if (!sharding) {
+      continue;
+    }
+    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
+      updateMeshAxisAssignmentForLoopIterators(
+          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+          meshAxisAssignmentForLoopIterators);
+    }
+  }
+
+  ShardingArray res;
+  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
+                  [](std::optional<SmallVector<MeshAxis>> &axes) {
+                    if (!axes) {
+                      return SmallVector<MeshAxis>();
+                    };
+                    return std::move(*axes);
+                  });
+  return res;
+}
+
+bool mesh::isAtLeastOneReductionIteratorSharded(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+  for (auto [loopIteratorType, meshAxisAssignment] :
+       llvm::zip(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    if (loopIteratorType == utils::IteratorType::reduction &&
+        !meshAxisAssignment.empty()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+  SmallVector<MeshAxis> meshAxes;
+  for (auto [loopIteratorType, meshAxisAssignment] :
+       llvm::zip(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    if (loopIteratorType == utils::IteratorType::reduction) {
+      meshAxes.insert(meshAxes.end(), meshAxisAssignment.begin(),
----------------
antiagainst wrote:

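Could this use `llvm::append_range(meshAxes, meshAxisAssignment)` instead of the manual `meshAxes.insert(meshAxes.end(), ...begin(), ...end())` call?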

https://github.com/llvm/llvm-project/pull/82284


More information about the llvm-commits mailing list