[Mlir-commits] [mlir] 7a4c497 - [mlir][mesh] Use one type for mesh axis (#76830)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 3 15:47:15 PST 2024


Author: Boian Petkantchin
Date: 2024-01-03T15:47:11-08:00
New Revision: 7a4c49756db161ebcce08c7bc860a569aad7f276

URL: https://github.com/llvm/llvm-project/commit/7a4c49756db161ebcce08c7bc860a569aad7f276
DIFF: https://github.com/llvm/llvm-project/commit/7a4c49756db161ebcce08c7bc860a569aad7f276.diff

LOG: [mlir][mesh] Use one type for mesh axis (#76830)

Make all ops and attributes use the types MeshAxis and MeshAxesAttr
instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
    mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a9d30dfbb9a76e..060d54b82efa63 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -80,8 +80,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let parameters = (ins
     AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
-    ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
-    OptionalArrayRefParameter<"int32_t">:$partial_axes,
+    ArrayRefParameter<"MeshAxesAttr">:$split_axes,
+    OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
   );
 
@@ -146,18 +146,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let builders = [
     AttrBuilder<(ins "SymbolRefAttr":$cluster, 
-                     "ArrayRef<SmallVector<int32_t>>":$split_axes,
-                     "ArrayRef<int32_t>": $partial_axes,
+                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
+                     "ArrayRef<MeshAxis>": $partial_axes,
                      "mesh::Partial": $partial_type), [{
-      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
-                  split_axes, [&](ArrayRef<int32_t> array) {
-          return DenseI32ArrayAttr::get($_ctxt, array);
+      SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
+                  split_axes, [&](ArrayRef<MeshAxis> array) {
+          return MeshAxesAttr::get($_ctxt, array);
       });
       return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
                    partial_type);
     }]>,
     AttrBuilder<(ins "SymbolRefAttr":$cluster, 
-                     "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
       return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
     }]>
   ];

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index ce7d5d045122d9..83452dcc2e8abe 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -17,6 +17,15 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include <algorithm>
 
+namespace mlir {
+namespace mesh {
+
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
+} // namespace mesh
+} // namespace mlir
+
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
@@ -30,9 +39,6 @@
 namespace mlir {
 namespace mesh {
 
-using MeshAxis = int16_t;
-using MeshAxesAttr = DenseI16ArrayAttr;
-
 bool isReductionLoop(IteratorType iType);
 
 bool areReductionAndPartialMatch(IteratorType iType, Partial partial);

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1ed54b6519e4d8..1934bdfb427059 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -114,7 +114,7 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMeth
 
   let builders = [
     OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
-    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
 
@@ -228,7 +228,7 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
   }];
   let builders = [
     OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
-    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
 

diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 270955a3036e89..201c0151754eba 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -18,8 +18,8 @@ class Operation;
 
 namespace mesh {
 
-using ShardingArray = SmallVector<SmallVector<int32_t>>;
-using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
+using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
 
 struct ShardingOption {
   // An array of int array. The sub-array at the i-th position signifies the

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index de4f58d54e8ca5..c3d8f1d456106d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,15 +266,15 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
-                         ArrayRef<int32_t> partialAxes, Partial) {
+                         SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+                         ArrayRef<MeshAxis> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
-  llvm::SmallSet<int32_t, 4> visitedAxes;
+  llvm::SmallSet<MeshAxis, 4> visitedAxes;
 
-  auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
-    for (int32_t axis : axesArray) {
+  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
+    for (MeshAxis axis : axesArray) {
       if (axis < 0)
         return emitError() << "mesh axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
@@ -283,8 +283,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   };
 
-  for (DenseI32ArrayAttr subAxes : splitAxes) {
-    ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
+  for (MeshAxesAttr subAxes : splitAxes) {
+    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }
@@ -318,10 +318,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 
   return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
                                        getSplitAxes().end()),
-                      std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+                      std::mem_fn(&MeshAxesAttr::empty)) &&
          llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
                                        rhs.getSplitAxes().end()),
-                      std::mem_fn(&DenseI32ArrayAttr::empty));
+                      std::mem_fn(&MeshAxesAttr::empty));
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index a6f2f435f36d68..ee885ab16b7b06 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -216,7 +216,7 @@ namespace {
 static LogicalResult fillShardingOption(Operation *op,
                                         ShardingOption &shardingOption,
                                         SymbolRefAttr cluster,
-                                        ArrayRef<int32_t> meshAxes,
+                                        ArrayRef<MeshAxis> meshAxes,
                                         unsigned loopIdx) {
   if ((shardingOption.cluster && cluster &&
        shardingOption.cluster != cluster) ||
@@ -230,7 +230,7 @@ static LogicalResult fillShardingOption(Operation *op,
     if (i == loopIdx)
       continue;
 
-    for (int32_t axis : meshAxes) {
+    for (MeshAxis axis : meshAxes) {
       if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
         LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
                           << axis << " duplicate");
@@ -260,7 +260,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
   unsigned numOperands = op->getNumOperands();
   shardingOption.shardingArray.resize(loopTypes.size());
-  llvm::SmallVector<int32_t> partialMeshAxes;
+  llvm::SmallVector<MeshAxis> partialMeshAxes;
   Partial partialType;
   llvm::SmallSet<unsigned, 4> visitedLoopIndices;
   bool anyShardingInResultsOrOperands = false;
@@ -277,7 +277,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
     // shardingOption[index]
     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
       AffineExpr expr = std::get<0>(it);
-      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
       auto dim = cast<AffineDimExpr>(expr);
       unsigned index = dim.getPosition();
       visitedLoopIndices.insert(index);
@@ -288,7 +288,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 
     // Handle the partial axes: at this stage, the exact loop index/indices
     // cannot be decided because there could be multiple reduction loops.
-    ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
+    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
     if (!partialAxes.empty()) {
       if (!partialMeshAxes.empty())
         return op->emitOpError() << "at most one result with partial axes is "
@@ -321,7 +321,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
     // then the operands with multiple loop indices.
     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
       AffineExpr expr = std::get<0>(it);
-      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
       FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
           checkOperandAffineExpr(expr, numDims);
       if (failed(loopIndices))
@@ -362,7 +362,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
   if (!partialMeshAxes.empty()) {
     bool anyNonEmptyReductionLoop = llvm::any_of(
         llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
-          SmallVector<int32_t> &subArray = it.value();
+          SmallVector<MeshAxis> &subArray = it.value();
           int64_t idx = it.index();
           return isReductionLoop(loopTypes[idx]) && !subArray.empty();
         });
@@ -406,8 +406,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
     return success();
 
   auto resultType = result.getType().cast<RankedTensorType>();
-  SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
-  SmallVector<int32_t> partialAxes;
+  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
+  SmallVector<MeshAxis> partialAxes;
 
   // process the split axes
   for (auto it : llvm::enumerate(map.getResults())) {
@@ -431,7 +431,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
         assert(partialType == curPartialType &&
                "Only one reduction type is supported");
       partialType = curPartialType;
-      const SmallVector<int32_t> &axis = std::get<1>(it);
+      const SmallVector<MeshAxis> &axis = std::get<1>(it);
       partialAxes.append(axis);
     }
   }
@@ -459,7 +459,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
     return success();
   Value operand = opOperand.get();
   auto operandType = operand.getType().cast<RankedTensorType>();
-  SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
   unsigned numDims = map.getNumDims();
   for (auto it : llvm::enumerate(map.getResults())) {
     int64_t idx = it.index();

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 8d7e89662131a0..37b86535959652 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -147,7 +147,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
           .getResult()
           .cast<TypedValue<ShapedType>>();
 
-  llvm::SmallVector<int32_t> remainingPartialAxes;
+  llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,
                 std::back_inserter(allReduceMeshAxes),
                 [&targetShardingPartialAxesSet](Axis a) {
@@ -163,17 +163,17 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
 static MeshShardingAttr
 targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
                               int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
          splitTensorAxis) {
-    targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
   }
   auto targetSplitAxes =
       llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
   targetSplitAxes.push_back(splitMeshAxis);
   targetShardingSplitAxes[splitTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+      MeshAxesAttr::get(ctx, targetSplitAxes);
   return MeshShardingAttr::get(
       ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -356,7 +356,7 @@ static MeshShardingAttr
 targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                 MeshShardingAttr sourceSharding,
                                 int64_t splitTensorAxis) {
-  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
          splitTensorAxis);
@@ -365,7 +365,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
 
   targetSplitAxes.pop_back();
   targetShardingSplitAxes[splitTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+      MeshAxesAttr::get(ctx, targetSplitAxes);
   return MeshShardingAttr::get(
       ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -475,11 +475,11 @@ static MeshShardingAttr
 targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis) {
-  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
          targetTensorAxis) {
-    targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
   }
 
   auto sourceSplitAxes =
@@ -488,13 +488,13 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
   auto meshAxis = sourceSplitAxes.back();
   sourceSplitAxes.pop_back();
   targetShardingSplitAxes[sourceTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
+      MeshAxesAttr::get(ctx, sourceSplitAxes);
 
   auto targetSplitAxes =
       llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
   targetSplitAxes.push_back(meshAxis);
   targetShardingSplitAxes[targetTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+      MeshAxesAttr::get(ctx, targetSplitAxes);
 
   return MeshShardingAttr::get(
       ctx, sourceSharding.getCluster(), targetShardingSplitAxes,


        


More information about the Mlir-commits mailing list