[Mlir-commits] [mlir] [mlir][shard, mpi] Fixing lowering allgather shard->mpi->llvm (PR #178870)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 30 03:29:34 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
`shard.allgather` concatenates along a specified gather axis. However, `mpi.allgather` always concatenates along the first dimension, and there is no MPI operation that gathers along an arbitrary axis. Hence, if the gather axis is not 0, we need to create a temporary buffer into which we gather along the first dimension, and then copy from that buffer into the final output along the specified gather axis. This is not ideal, but it is the best I could do with reasonable effort.
Along the way, this PR also:
- fixes the computation of the memref size in MPIToLLVM
- adds a simple canonicalization pattern for `comm_size` for easier debugging
- adds more tests
---
Patch is 43.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178870.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+1)
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+25-13)
- (modified) mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (+90-6)
- (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+38-20)
- (modified) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+27-15)
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+93-7)
- (added) mlir/test/Dialect/MPI/canonicalize.mlir (+14)
- (modified) mlir/test/Dialect/Shard/partition.mlir (+54-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index b24c12459475e..d9e47ea3f6bfe 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -96,6 +96,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
);
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 4a1c5d1f7846c..9ddef319ad71c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -51,7 +51,8 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
std::pair<Value, Value> getRawPtrAndSize(const Location loc,
ConversionPatternRewriter &rewriter,
- Value memRef, Type elType) {
+ Value memRef, int64_t rank,
+ Type elType) {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value dataPtr =
LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
@@ -59,11 +60,16 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
rewriter.getI64Type(), memRef, 2);
Value resPtr =
LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
- Value size;
+ Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIndexAttr(1));
if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
- size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
+ for (int64_t i = 0; i < rank; ++i) {
+ Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
+ ArrayRef<int64_t>{3, i});
+ dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
+ size =
+ LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
+ }
} else {
size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
}
@@ -396,7 +402,7 @@ class OMPIImplTraits : public MPIImplTraits {
};
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
- auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
+ auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, false);
if (failed(attr))
return std::make_unique<MPICHImplTraits>(moduleOp);
auto strAttr = dyn_cast<StringAttr>(attr.value());
@@ -634,7 +640,7 @@ struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
- moduleOp, loc, rewriter, "MPI_Comm_Size", SizeFuncType);
+ moduleOp, loc, rewriter, "MPI_Comm_size", SizeFuncType);
// replace with function call
auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
@@ -675,6 +681,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type elemType = op.getRef().getType().getElementType();
+ int64_t rank = op.getRef().getType().getRank();
// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -684,7 +691,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
// get MPI_COMM_WORLD, dataType and pointer
auto [dataPtr, size] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -727,6 +734,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Type elemType = op.getRef().getType().getElementType();
+ int64_t rank = op.getRef().getType().getRank();
// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -736,7 +744,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
auto [dataPtr, size] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -782,10 +790,12 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
MLIRContext *context = rewriter.getContext();
Type sElemType = op.getSendbuf().getType().getElementType();
Type rElemType = op.getRecvbuf().getType().getElementType();
+ int64_t sRank = op.getSendbuf().getType().getRank();
+ int64_t rRank = op.getRecvbuf().getType().getRank();
auto [sendPtr, sendSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
auto [recvPtr, recvSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
@@ -843,15 +853,17 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Type elemType = op.getSendbuf().getType().getElementType();
+ int64_t sRank = op.getSendbuf().getType().getRank();
+ int64_t rRank = op.getRecvbuf().getType().getRank();
// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
auto [sendPtr, sendSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
auto [recvPtr, recvSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
// If input and output are the same, request in-place operation.
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 87ae28892fcf7..c765ad5a579c8 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -30,6 +30,7 @@
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -620,6 +621,13 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
using CommOpPattern::CommOpPattern;
+ // shard.allgather concatenates along a specified gather-axis.
+ // mpi.allgather always concatenates along the first dimension and
+ // there is no MPI operation that allows gathering along an arbitrary axis.
+ // Hence, if gather-axis!=0, we need to create a temporary buffer
+ // where we gather along the first dimension and then copy from that
+ // buffer to the final output along the specified gather-axis.
+
LogicalResult
matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -637,19 +645,95 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");
+ int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
+ auto ctx = op->getContext();
// Get the right communicator
Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
- // Allocate output buffer
- Value output = memref::AllocOp::create(iBuilder, outType);
+
+ Value nRanks =
+ mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
+ .getSize();
+ nRanks =
+ arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
+
+ Value tmpOutput, gatherDimSz;
+ if (gatherAxis == 0) {
+ tmpOutput = memref::AllocOp::create(iBuilder, outType);
+ } else {
+ // MPI's allgather always concatenates along the first dimension.
+ // Create a memref type for the output buffer with adjusted (expanded)
+ // shape.
+ SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
+ llvm::append_range(gatherShape, outType.getShape());
+ gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
+ MemRefType gatherType =
+ MemRefType::get(gatherShape, outType.getElementType());
+ gatherDimSz = arith::ConstantIndexOp::create(
+ iBuilder, outType.getDimSize(gatherAxis));
+ gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
+ gatherDimSz, nRanks);
+ // Allocate output buffer
+ tmpOutput =
+ memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
+ }
// Create the MPI AllGather operation.
- mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+ mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
+
+ // If gather-axis!=0, copy from gathered buffer to output with the right
+ // layout.
+ Value finalOutput = tmpOutput;
+ if (gatherAxis != 0) {
+ int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
+ assert(nSrcDims == outType.getRank() + 1 &&
+ "Expected gathered type to have rank one more than output type.");
+
+ // Create affine map for copying from gathered buffer to output.
+ SmallVector<AffineExpr> dims;
+ dims.reserve(nSrcDims);
+ for (unsigned i = 0; i < nSrcDims; ++i)
+ dims.emplace_back(getAffineDimExpr(i, ctx));
+ AffineExpr s = getAffineSymbolExpr(0, ctx);
+ SmallVector<AffineExpr> results;
+ results.reserve(nSrcDims);
+ for (unsigned i = 0; i < nSrcDims - 1; ++i) {
+ if (i == gatherAxis)
+ results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
+ else
+ results.emplace_back(dims[i + 1]);
+ }
+ auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
+
+ finalOutput = memref::AllocOp::create(iBuilder, outType);
+
+ // Now build a loop nest to copy from gathered buffer to finalOutput
+ // It would be nicer to just use a memref.transpose/collapse_shape op but
+ // these currently only support simpler cases.
+ Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
+ SmallVector<Value> lbs(nSrcDims, zero);
+ SmallVector<Value> ubs;
+ for (int64_t d = 0; d < nSrcDims; ++d)
+ ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
+ SmallVector<int64_t> steps(nSrcDims, 1);
+ auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+ Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs);
+ // set symbol value
+ SmallVector<Value> ivss(ivs.begin(), ivs.end());
+ ivss.emplace_back(gatherDimSz);
+ affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap,
+ ivss);
+ };
+ affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
+ emitCopy);
+
+ memref::DeallocOp::create(iBuilder, tmpOutput);
+ }
// If the destination is a tensor, cast it to a tensor
if (isa<RankedTensorType>(op.getType()))
- output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
- true);
- rewriter.replaceOp(op, output);
+ finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
+ finalOutput, true);
+ rewriter.replaceOp(op, finalOutput);
return success();
}
};
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index f52c3f99189d2..e47a1f62d76d0 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -43,33 +43,46 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
}
};
+template <typename OpT>
+static LogicalResult FoldToDLTIConst(OpT op, const char *key,
+ mlir::PatternRewriter &b) {
+ auto comm = op.getComm();
+ if (!comm.template getDefiningOp<mlir::mpi::CommWorldOp>())
+ return mlir::failure();
+
+ // Try to get the DLTI attribute named by `key` (e.g. MPI:comm_world_rank).
+ // If found, replace the op's result with a constant holding its value.
+ auto dltiAttr = dlti::query(op, {key}, false);
+ if (failed(dltiAttr))
+ return mlir::failure();
+ if (!isa<IntegerAttr>(dltiAttr.value()))
+ return op->emitError() << "Expected an integer attribute for " << key;
+ Value res = arith::ConstantOp::create(
+ b, op.getLoc(), b.getI32Type(),
+ b.getI32IntegerAttr(cast<IntegerAttr>(dltiAttr.value()).getInt()));
+ if (Value retVal = op.getRetval())
+ b.replaceOp(op, {retVal, res});
+ else
+ b.replaceOp(op, res);
+ return mlir::success();
+}
+
struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
-
LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
mlir::PatternRewriter &b) const override {
- auto comm = op.getComm();
- if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
- return mlir::failure();
-
- // Try to get DLTI attribute for MPI:comm_world_rank
- // If found, set worldRank to the value of the attribute.
- auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
- if (failed(dltiAttr))
- return mlir::failure();
- if (!isa<IntegerAttr>(dltiAttr.value()))
- return op->emitError()
- << "Expected an integer attribute for MPI:comm_world_rank";
- Value res = arith::ConstantIndexOp::create(
- b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
- if (Value retVal = op.getRetval())
- b.replaceOp(op, {retVal, res});
- else
- b.replaceOp(op, res);
- return mlir::success();
+ return FoldToDLTIConst(op, "MPI:comm_world_rank", b);
}
};
+struct FoldSize final : public mlir::OpRewritePattern<mlir::mpi::CommSizeOp> {
+ using mlir::OpRewritePattern<mlir::mpi::CommSizeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::mpi::CommSizeOp op,
+ mlir::PatternRewriter &b) const override {
+ return FoldToDLTIConst(op, "MPI:comm_world_size", b);
+ }
+};
} // namespace
void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -97,6 +110,11 @@ void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
results.add<FoldRank>(context);
}
+void mlir::mpi::CommSizeOp::getCanonicalizationPatterns(
+ mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+ results.add<FoldSize>(context);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 4c1beee2fe144..c6d4ca1eb2e33 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -4,7 +4,7 @@
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Comm_Size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
@@ -40,7 +40,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+ // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
+ // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -50,7 +51,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+ // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
+ // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -60,7 +62,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+ // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
+ // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
@@ -72,7 +75,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+ // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
+ // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
// C...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/178870
More information about the Mlir-commits
mailing list