[Mlir-commits] [mlir] More on MeshToMPI (PR #129048)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 27 05:10:32 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

- do not create MPI operations if no halo exchange is needed on a given boundary
- allow returning sharding information by returning a `!mesh.sharding` value (it gets converted into a tuple of tensors)
- lower `mesh.shard_shape`, including fixes to the operation itself
- replace the global symbol `static_mpi_rank` with a DLTI attribute (now aligned with MPIToLLVM)
- smaller fixes and some minor cleanup

---

Patch is 79.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129048.diff


10 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+4-4) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+15-7) 
- (modified) mlir/lib/Conversion/MeshToMPI/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+510-102) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+11-4) 
- (modified) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+10-9) 
- (modified) mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (+211-116) 
- (added) mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir (+75) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+6-4) 
- (modified) mlir/test/Dialect/Tensor/mesh-spmdization.mlir (+7-4) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..6074e0e8d822c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let description = [{
     This pass converts communication operations from the Mesh dialect to the
     MPI dialect.
-    If it finds a global named "static_mpi_rank" it will use that splat value
-    instead of calling MPI_Comm_rank. This allows optimizations like constant
-    shape propagation and fusion because shard/partition sizes depend on the
-    rank.
+    If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
+    use that integer value instead of calling MPI_Comm_rank. This allows
+    optimizations like constant shape propagation and fusion because
+    shard/partition sizes depend on the rank.
   }];
   let dependentDialects = [
     "memref::MemRefDialect",
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 031e6f63bcb42..f59c4c4c67517 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
   }];
 }
 
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
-  let summary = "Get the shard shape of a given process/device.";
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+    Pure, AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Get the shard shape for a given process/device.";
   let description = [{
-    The device/process id is a linearized id of the device/process in the mesh.
+    The device/process id is a multi-index of the device/process in the mesh.
     This operation might be used during spmdization when the shard shape depends
     on (non-constant) values used in `mesh.sharding`.
   }];
   let arguments = (ins
-    DenseI64ArrayAttr:$shape,
+    DenseI64ArrayAttr:$dims,
+    Variadic<Index>:$dims_dynamic,
     Mesh_Sharding:$sharding,
-    Index:$device
+    DenseI64ArrayAttr:$device,
+    Variadic<Index>:$device_dynamic
   );
   let results = (outs Variadic<Index>:$result);
   let assemblyFormat = [{
-      custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+      `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
+      `sharding` `=` $sharding
+      `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
+      attr-dict `:` type(results)
   }];
   let builders = [
-    OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+    OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
   ];
 }
 
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
index 95815a683f6d6..15560aa61e145 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
   Core
 
   LINK_LIBS PUBLIC
+  MLIRDLTIDialect
   MLIRFuncDialect
   MLIRIR
   MLIRLinalgTransforms
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 48b3764d520c2..84db6d456711c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,8 +14,13 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +30,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
 
 namespace {
-// Create operations converting a linear index to a multi-dimensional index
+/// Convert vec of OpFoldResults (ints) into vector of Values.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+                                           llvm::ArrayRef<int64_t> statics,
+                                           ValueRange dynamics,
+                                           Type type = Type()) {
+  SmallVector<Value> values;
+  auto dyn = dynamics.begin();
+  Type i64 = b.getI64Type();
+  if (!type)
+    type = i64;
+  assert(i64 == type || b.getIndexType() == type);
+  for (auto s : statics) {
+    values.emplace_back(
+        ShapedType::isDynamic(s)
+            ? *(dyn++)
+            : b.create<arith::ConstantOp>(loc, type,
+                                          i64 == type ? b.getI64IntegerAttr(s)
+                                                      : b.getIndexAttr(s)));
+  }
+  return values;
+};
+
+/// Create operations converting a linear index to a multi-dimensional index.
 static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                              Value linearIndex,
                                              ValueRange dimensions) {
@@ -48,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
 
   for (int i = n - 1; i >= 0; --i) {
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if (i > 0) {
+    if (i > 0)
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
-    }
   }
 
   return multiIndex;
 }
 
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
 Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                          ValueRange dimensions) {
 
-  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
-  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
     linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
     stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
@@ -72,35 +99,179 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp)
+      return failure();
+    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+    if (!shardingOp)
+      return failure();
+
+    rewriter.replaceOp(op, shardingOp.getResult());
+    return success();
+  }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+///   (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splitAxes = op.getSplitAxes().getAxes();
+    int64_t maxNAxes = 0;
+    for (auto axes : splitAxes) {
+      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+    }
+
+    // To hold the split axes, create empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
+    Location loc = op.getLoc();
+    auto i16 = rewriter.getI16Type();
+    auto i64 = rewriter.getI64Type();
+    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+    auto attr = IntegerAttr::get(i16, 0xffff);
+    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+                       .getResult(0);
+
+    // explicitly write values into tensor row by row
+    int64_t strides[] = {1, 1};
+    int64_t nSplits = 0;
+    ValueRange empty = {};
+    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+      int64_t size = axes.size();
+      if (size > 0)
+        ++nSplits;
+      int64_t offs[] = {(int64_t)i, 0};
+      int64_t sizes[] = {1, size};
+      auto tensorType = RankedTensorType::get({size}, i16);
+      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+    }
+
+    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
+    // Store the halo sizes in the tensor.
+    auto haloSizes =
+        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+                         adaptor.getDynamicHaloSizes());
+    auto type = RankedTensorType::get({nSplits, 2}, i64);
+    Value resHaloSizes =
+        haloSizes.empty()
+            ? rewriter
+                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+                                           i64)
+                  .getResult()
+            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+                  .getResult();
+
+    // To hold sharded dims offsets, create Tensor with shape {nSplits,
+    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+    // elements for smaller split-groups to -1. Computing the max size of the
+    // split groups needs using collectiveProcessGroupSize (which needs the
+    // MeshOp)
+    Value resOffsets;
+    if (adaptor.getStaticShardedDimsOffsets().empty()) {
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{0, 0}, i64);
+    } else {
+      SymbolTableCollection symbolTableCollection;
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto maxSplitSize = 0;
+      for (auto axes : splitAxes) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic);
+        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+      }
+      assert(maxSplitSize);
+      ++maxSplitSize; // add one for the total size
+
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+      resOffsets =
+          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+      auto offsets =
+          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+                           adaptor.getDynamicShardedDimsOffsets());
+      int64_t curr = 0;
+      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+        ++splitSize; // add one for the total size
+        ArrayRef<Value> values(&offsets[curr], splitSize);
+        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+        int64_t offs[] = {(int64_t)i, 0};
+        int64_t sizes[] = {1, splitSize};
+        resOffsets = rewriter.create<tensor::InsertSliceOp>(
+            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+        curr += splitSize;
+      }
+    }
+
+    // return a tuple of tensors as defined by type converter
+    SmallVector<Type> resTypes;
+    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+                                               resTypes)))
+      return failure();
+
+    resSplitAxes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+    resHaloSizes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TupleType::get(op.getContext(), resTypes),
+        ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+    return success();
+  }
+};
+
 struct ConvertProcessMultiIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public OpConversionPattern<ProcessMultiIndexOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Currently converts its linear index to a multi-dimensional index.
 
     SymbolTableCollection symbolTableCollection;
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     auto meshOp = getMesh(op, symbolTableCollection);
     // For now we only support static mesh shapes
-    if (ShapedType::isDynamicShape(meshOp.getShape())) {
-      return mlir::failure();
-    }
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
 
     SmallVector<Value> dims;
     llvm::transform(
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto rank =
-        rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
     auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
 
     // optionally extract subset of mesh axes
-    auto axes = op.getAxes();
+    auto axes = adaptor.getAxes();
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
       for (auto axis : axes) {
@@ -110,32 +281,33 @@ struct ConvertProcessMultiIndexOp
     }
 
     rewriter.replaceOp(op, mIdx);
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertProcessLinearIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-
-    // Finds a global named "static_mpi_rank" it will use that splat value.
-    // Otherwise it defaults to mpi.comm_rank.
-
-    auto loc = op.getLoc();
-    auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
-    if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
-            op, rankOpName)) {
-      if (auto initTnsr = globalOp.getInitialValueAttr()) {
-        auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
-        rewriter.replaceOp(op,
-                           rewriter.create<arith::ConstantIndexOp>(loc, val));
-        return mlir::success();
-      }
+class ConvertProcessLinearIndexOp
+    : public OpConversionPattern<ProcessLinearIndexOp> {
+  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  // Constructor accepting worldRank
+  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
+                              MLIRContext *context, int64_t worldRank_ = -1)
+      : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+
+  LogicalResult
+  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
+      return success();
     }
+
+    // Otherwise call create mpi::CommRankOp
     auto rank =
         rewriter
             .create<mpi::CommRankOp>(
@@ -144,44 +316,43 @@ struct ConvertProcessLinearIndexOp
             .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
-    return mlir::success();
+    return success();
   }
 };
 
 struct ConvertNeighborsLinearIndicesOp
-    : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public OpConversionPattern<NeighborsLinearIndicesOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Computes the neighbors indices along a split axis by simply
     // adding/subtracting 1 to the current index in that dimension.
     // Assigns -1 if neighbor is out of bounds.
 
-    auto axes = op.getSplitAxes();
+    auto axes = adaptor.getSplitAxes();
     // For now only single axis sharding is supported
-    if (axes.size() != 1) {
-      return mlir::failure();
-    }
+    if (axes.size() != 1)
+      return failure();
 
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     SymbolTableCollection symbolTableCollection;
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto mIdx = op.getDevice();
+    auto mIdx = adaptor.getDevice();
     auto orgIdx = mIdx[axes[0]];
     SmallVector<Value> dims;
     llvm::transform(
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto dimSz = dims[axes[0]];
-    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
-    auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
-    auto atBorder = rewriter.create<arith::CmpIOp>(
+    Value dimSz = dims[axes[0]];
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
+    Value atBorder = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sle, orgIdx,
-        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
     auto down = rewriter.create<scf::IfOp>(
         loc, atBorder,
         [&](OpBuilder &builder, Location loc) {
@@ -206,23 +377,161 @@ struct ConvertNeighborsLinearIndicesOp
         [&](OpBuilder &builder, Location loc) {
           SmallVector<Value> tmp = mIdx;
           tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
           builder.create<scf::YieldOp>(
               loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return mlir::success();
+    return success();
   }
 };
 
-...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/129048


More information about the Mlir-commits mailing list