[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 8 18:59:09 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>

This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the in-group process index.

Make resharding generate an all-scatter instead of inserting the slicing logic directly.

---

Patch is 55.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81218.diff


26 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+7) 
- (modified) mlir/include/mlir/Dialect/Arith/Utils/Utils.h (+17) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+6) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+63) 
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (+18-1) 
- (modified) mlir/include/mlir/IR/Builders.h (+3) 
- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+27) 
- (modified) mlir/lib/Dialect/Arith/Utils/Utils.cpp (+19) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+61-3) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+2) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+5-11) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+8-76) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+125-11) 
- (added) mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h (+35) 
- (modified) mlir/lib/IR/Builders.cpp (+18) 
- (added) mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir (+74) 
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+13) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+52) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+24) 
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+6-25) 
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+6-8) 
- (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+1-1) 
- (added) mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp (+79) 
- (removed) mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp (-54) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2-2) 
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
 
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
 class DominanceInfo;
 class Operation;
 class PostDominanceInfo;
+class ImplicitLocOpBuilder;
 
 namespace func {
 class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
+// Generate IR that extracts the linear index from a multi-index according to
+// a shape.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                  ArrayRef<OpFoldResult> shape,
+                                  ImplicitLocOpBuilder &builder);
 
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
 
@@ -81,6 +82,22 @@ struct ArithBuilder {
   OpBuilder &b;
   Location loc;
 };
+
+namespace arith {
+
+// Build the product of a sequence.
+// If values = (v0, v1, ..., vn) then the returned
+// value is v0 * v1 * ... * vn.
+// All values must have the same type.
+//
+// The version without `resultType` requires `values` to contain at least one
+// element; the result then has the same type as the elements of `values`.
+// The version with `resultType` returns 1 of type `resultType` if `values` is
+// empty.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType);
+} // namespace arith
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+                                    mesh.getShape());
+}
+
 // Get the size of a sharded dimension.
 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
   if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
 
   let builders = [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
   let hasCanonicalizer = 1;
 }
 
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+  let description = [{
+    Scatter along the `scatter_axis` tensor axis.
+    This operation can be thought of as the inverse of all-gather.
+    Technically, it is not required that all processes have the same input tensor.
+    Each process will slice a piece of its local tensor based on its in-group device index.
+    The operation does not communicate data between devices. 
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh0(shape = 2x2)
+    ...
+    %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+      : tensor<2x4xi8> -> tensor<2x2xi8>
+    ```
+    Input:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Result:
+    ```
+    scatter tensor
+    axis 1
+    ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+  ];
+}
+
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     Pure,
     SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 class RewritePatternSet;
 class SymbolTableCollection;
 class DialectRegistry;
+class ImplicitLocOpBuilder;
 namespace mesh {
 
 void processMultiIndexOpLoweringPopulatePatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
 void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
 
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
+  // Returns a 1-valued attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  TypedAttr getOneAttr(Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   results.push_back(residual);
   return results;
 }
+
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                   ArrayRef<OpFoldResult> shape,
+                                   ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == shape.size());
+  SmallVector<AffineExpr> shapeAffine;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+  }
+
+  SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(stridesAffine.size());
+  llvm::transform(stridesAffine, std::back_inserter(strides),
+                  [&builder, &shape](AffineExpr strideExpr) {
+                    return affine::makeComposedFoldedAffineApply(
+                        builder, builder.getLoc(), strideExpr, shape);
+                  });
+
+  auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+  return affine::makeComposedFoldedAffineApply(
+      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
 }
+
+namespace mlir::arith {
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+  return createProduct(builder, loc, values, values.front().getType());
+}
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType) {
+  Value one = builder.create<ConstantOp>(loc, resultType,
+                                         builder.getOneAttr(resultType));
+  ArithBuilder arithBuilder(builder, loc);
+  return std::accumulate(
+      values.begin(), values.end(), one,
+      [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         MeshOp mesh) {
+  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(),
-        MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+                          odsBuilder.getIndexType()),
+        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         StringRef mesh, ArrayRef<MeshAxis> axes) {
+  assert(!axes.empty());
   build(odsBuilder, odsState,
         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
   return success();
 }
 
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+                                          ArrayRef<MeshAxis> meshAxes,
+                                          int64_t scatterAxis) {
+  RankedTensorType operandRankedTensorType =
+      operandType.cast<RankedTensorType>();
+  DimensionSize operandScatterAxisSize =
+      operandRankedTensorType.getShape()[scatterAxis];
+  SmallVector<int64_t> resultShape =
+      llvm::to_vector(operandRankedTensorType.getShape());
+
+  resultShape[scatterAxis] =
+      operandScatterAxisSize /
+      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+  return operandRankedTensorType.clone(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_gather op
 //===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyScatterOperandAndResultShape(
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().getShape());
+}
+
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                         int64_t scatterAxis) {
+  Type resultType =
+      scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+        scatterAxis);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Type resultType, Value input, StringRef mesh,
+                         ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+        APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRControlFlowDialect
   MLIRFuncDialect
   MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include <iterator>
 #include <numeric>
 #include <utility>
 
@@ -56,13 +56,10 @@ namespace {
 // symbol tables.
 // We can't use DialectFoldInterface since the cache may be invalidated by some
 // pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
-  template <typename... OpRewritePatternArgs>
-  MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
-                  OpRewritePatternArgs &&...opRewritePatternArgs)
-      : OpRewritePattern(
-            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
-        symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+    : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
   LogicalResult matchAndRewrite(MeshShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
 
     return success();
   }
-
-private:
-  SymbolTableCollection &symbolTableCollection;
 };
 
 } // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
-                                             int64_t splitTensorAxis,
-                                             int64_t splitCount) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
-  targetShape[splitTensorAxis] =
-      shardDimension(targetShape[splitTensorAxis], splitCount);
-  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
 // Split a replicated tensor along a mesh axis.
 // e.g. [[0, 1]] -> [[0, 1, 2]].
 // Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshShardingAttr sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  MLIRContext *ctx = builder.getContext();
-  builder.setInsertionPointAfterValue(sourceShard);
-
-  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
-  Value processIndexAlongAxis =
+  TypedValue<ShapedType> targetShard =
       builder
-          .create<ProcessMultiIndexOp>(mesh.getSymName(),
-                                       SmallVector<MeshAxis>({splitMeshAxis}))
-          .getResult()[0];
-
+          .create<AllScatterOp>(sourceShard, mesh,
+                                ArrayRef<MeshAxis>(splitMeshAxis),
+                                splitTensorAxis)
+          .ge...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81218


More information about the Mlir-commits mailing list