[Mlir-commits] [mlir] [mlir][shard, mpi] Fixing lowering allgather shard->mpi->llvm (PR #178870)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 30 03:29:34 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

`shard.allgather` concatenates along a specified gather-axis. However, `mpi.allgather` always concatenates along the first dimension, and there is no MPI operation that allows gathering along an arbitrary axis. Hence, if gather-axis != 0, we need to create a temporary buffer into which we gather along the first dimension and then copy from that buffer to the final output along the specified gather-axis. This is not ideal, but it is the best I could do with reasonable effort.

Along the way, this PR also:
- fixes the computation of the memref size in mpitollvm
- adds a simple canonicalization pattern for comm_size for easier debugging
- adds more tests

---

Patch is 43.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178870.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+1) 
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+25-13) 
- (modified) mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (+90-6) 
- (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+38-20) 
- (modified) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+27-15) 
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+93-7) 
- (added) mlir/test/Dialect/MPI/canonicalize.mlir (+14) 
- (modified) mlir/test/Dialect/Shard/partition.mlir (+54-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index b24c12459475e..d9e47ea3f6bfe 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -96,6 +96,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
   );
 
   let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 4a1c5d1f7846c..9ddef319ad71c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -51,7 +51,8 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
 
 std::pair<Value, Value> getRawPtrAndSize(const Location loc,
                                          ConversionPatternRewriter &rewriter,
-                                         Value memRef, Type elType) {
+                                         Value memRef, int64_t rank,
+                                         Type elType) {
   Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
   Value dataPtr =
       LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
@@ -59,11 +60,16 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
                                               rewriter.getI64Type(), memRef, 2);
   Value resPtr =
       LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
-  Value size;
+  Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                        rewriter.getIndexAttr(1));
   if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
-    size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
-                                        ArrayRef<int64_t>{3, 0});
-    size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
+    for (int64_t i = 0; i < rank; ++i) {
+      Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
+                                               ArrayRef<int64_t>{3, i});
+      dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
+      size =
+          LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
+    }
   } else {
     size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
   }
@@ -396,7 +402,7 @@ class OMPIImplTraits : public MPIImplTraits {
 };
 
 std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
-  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
+  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, false);
   if (failed(attr))
     return std::make_unique<MPICHImplTraits>(moduleOp);
   auto strAttr = dyn_cast<StringAttr>(attr.value());
@@ -634,7 +640,7 @@ struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
         LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
     // get or create function declaration:
     LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
-        moduleOp, loc, rewriter, "MPI_Comm_Size", SizeFuncType);
+        moduleOp, loc, rewriter, "MPI_Comm_size", SizeFuncType);
 
     // replace with function call
     auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
@@ -675,6 +681,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
     Type elemType = op.getRef().getType().getElementType();
+    int64_t rank = op.getRef().getType().getRank();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -684,7 +691,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
 
     // get MPI_COMM_WORLD, dataType and pointer
     auto [dataPtr, size] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -727,6 +734,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
     Type elemType = op.getRef().getType().getElementType();
+    int64_t rank = op.getRef().getType().getRank();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -736,7 +744,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
 
     // get MPI_COMM_WORLD, dataType, status_ignore and pointer
     auto [dataPtr, size] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -782,10 +790,12 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
     MLIRContext *context = rewriter.getContext();
     Type sElemType = op.getSendbuf().getType().getElementType();
     Type rElemType = op.getRecvbuf().getType().getElementType();
+    int64_t sRank = op.getSendbuf().getType().getRank();
+    int64_t rRank = op.getRecvbuf().getType().getRank();
     auto [sendPtr, sendSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
     auto [recvPtr, recvSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
 
     auto moduleOp = op->getParentOfType<ModuleOp>();
     auto mpiTraits = MPIImplTraits::get(moduleOp);
@@ -843,15 +853,17 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
     Type elemType = op.getSendbuf().getType().getElementType();
+    int64_t sRank = op.getSendbuf().getType().getRank();
+    int64_t rRank = op.getRecvbuf().getType().getRank();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
     auto moduleOp = op->getParentOfType<ModuleOp>();
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     auto [sendPtr, sendSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
     auto [recvPtr, recvSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
 
     // If input and output are the same, request in-place operation.
     if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 87ae28892fcf7..c765ad5a579c8 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -620,6 +621,13 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
 struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
   using CommOpPattern::CommOpPattern;
 
+  // shard.allgather concatenates along a specified gather-axis.
+  // mpi.allgather always concatenates along the first dimension and
+  // there is no MPI operation that allows gathering along an arbitrary axis.
+  // Hence, if gather-axis!=0, we need to create a temporary buffer
+  // where we gather along the first dimension and then copy from that
+  // buffer to the final output along the specified gather-axis.
+
   LogicalResult
   matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -637,19 +645,95 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
     if (!memref::isStaticShapeAndContiguousRowMajor(outType))
       return op.emitError(
           "Expected static shaped memref in contiguous row-major layout.");
+    int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
+    auto ctx = op->getContext();
 
     // Get the right communicator
     Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
-    // Allocate output buffer
-    Value output = memref::AllocOp::create(iBuilder, outType);
+
+    Value nRanks =
+        mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
+            .getSize();
+    nRanks =
+        arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
+
+    Value tmpOutput, gatherDimSz;
+    if (gatherAxis == 0) {
+      tmpOutput = memref::AllocOp::create(iBuilder, outType);
+    } else {
+      // MPI's allgather always concatenates along the first dimension.
+      // Create a memref type for the output buffer with adjusted (expanded)
+      // shape.
+      SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
+      llvm::append_range(gatherShape, outType.getShape());
+      gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
+      MemRefType gatherType =
+          MemRefType::get(gatherShape, outType.getElementType());
+      gatherDimSz = arith::ConstantIndexOp::create(
+          iBuilder, outType.getDimSize(gatherAxis));
+      gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
+                                           gatherDimSz, nRanks);
+      // Allocate output buffer
+      tmpOutput =
+          memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
+    }
     // Create the MPI AllGather operation.
-    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
+
+    // If gather-axis!=0, copy from gathered buffer to output with the right
+    // layout.
+    Value finalOutput = tmpOutput;
+    if (gatherAxis != 0) {
+      int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
+      assert(nSrcDims == outType.getRank() + 1 &&
+             "Expected gathered type to have rank one more than output type.");
+
+      // Create affine map for copying from gathered buffer to output.
+      SmallVector<AffineExpr> dims;
+      dims.reserve(nSrcDims);
+      for (unsigned i = 0; i < nSrcDims; ++i)
+        dims.emplace_back(getAffineDimExpr(i, ctx));
+      AffineExpr s = getAffineSymbolExpr(0, ctx);
+      SmallVector<AffineExpr> results;
+      results.reserve(nSrcDims);
+      for (unsigned i = 0; i < nSrcDims - 1; ++i) {
+        if (i == gatherAxis)
+          results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
+        else
+          results.emplace_back(dims[i + 1]);
+      }
+      auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
+
+      finalOutput = memref::AllocOp::create(iBuilder, outType);
+
+      // Now build a loop nest to copy from gathered buffer to finalOutput
+      // It would be nicer to just use a memref.transpose/collapse_shape op but
+      // these currently only support simpler cases.
+      Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
+      SmallVector<Value> lbs(nSrcDims, zero);
+      SmallVector<Value> ubs;
+      for (int64_t d = 0; d < nSrcDims; ++d)
+        ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
+      SmallVector<int64_t> steps(nSrcDims, 1);
+      auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+        Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs);
+        // set symbol value
+        SmallVector<Value> ivss(ivs.begin(), ivs.end());
+        ivss.emplace_back(gatherDimSz);
+        affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap,
+                                      ivss);
+      };
+      affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
+                                  emitCopy);
+
+      memref::DeallocOp::create(iBuilder, tmpOutput);
+    }
 
     // If the destination is a tensor, cast it to a tensor
     if (isa<RankedTensorType>(op.getType()))
-      output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
-                                                 true);
-    rewriter.replaceOp(op, output);
+      finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
+                                                      finalOutput, true);
+    rewriter.replaceOp(op, finalOutput);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index f52c3f99189d2..e47a1f62d76d0 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -43,33 +43,46 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
   }
 };
 
+template <typename OpT>
+static LogicalResult FoldToDLTIConst(OpT op, const char *key,
+                                     mlir::PatternRewriter &b) {
+  auto comm = op.getComm();
+  if (!comm.template getDefiningOp<mlir::mpi::CommWorldOp>())
+    return mlir::failure();
+
+  // Try to get DLTI attribute for MPI:comm_world_rank
+  // If found, set worldRank to the value of the attribute.
+  auto dltiAttr = dlti::query(op, {key}, false);
+  if (failed(dltiAttr))
+    return mlir::failure();
+  if (!isa<IntegerAttr>(dltiAttr.value()))
+    return op->emitError() << "Expected an integer attribute for " << key;
+  Value res = arith::ConstantOp::create(
+      b, op.getLoc(), b.getI32Type(),
+      b.getI32IntegerAttr(cast<IntegerAttr>(dltiAttr.value()).getInt()));
+  if (Value retVal = op.getRetval())
+    b.replaceOp(op, {retVal, res});
+  else
+    b.replaceOp(op, res);
+  return mlir::success();
+}
+
 struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
   using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
-
   LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
                                 mlir::PatternRewriter &b) const override {
-    auto comm = op.getComm();
-    if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
-      return mlir::failure();
-
-    // Try to get DLTI attribute for MPI:comm_world_rank
-    // If found, set worldRank to the value of the attribute.
-    auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
-    if (failed(dltiAttr))
-      return mlir::failure();
-    if (!isa<IntegerAttr>(dltiAttr.value()))
-      return op->emitError()
-             << "Expected an integer attribute for MPI:comm_world_rank";
-    Value res = arith::ConstantIndexOp::create(
-        b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
-    if (Value retVal = op.getRetval())
-      b.replaceOp(op, {retVal, res});
-    else
-      b.replaceOp(op, res);
-    return mlir::success();
+    return FoldToDLTIConst(op, "MPI:comm_world_rank", b);
   }
 };
 
+struct FoldSize final : public mlir::OpRewritePattern<mlir::mpi::CommSizeOp> {
+  using mlir::OpRewritePattern<mlir::mpi::CommSizeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::mpi::CommSizeOp op,
+                                mlir::PatternRewriter &b) const override {
+    return FoldToDLTIConst(op, "MPI:comm_world_size", b);
+  }
+};
 } // namespace
 
 void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -97,6 +110,11 @@ void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
   results.add<FoldRank>(context);
 }
 
+void mlir::mpi::CommSizeOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldSize>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 4c1beee2fe144..c6d4ca1eb2e33 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -4,7 +4,7 @@
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Comm_Size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
@@ -40,7 +40,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
     // CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -50,7 +51,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
     // CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -60,7 +62,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
     // CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
@@ -72,7 +75,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
     // C...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/178870


More information about the Mlir-commits mailing list