[Mlir-commits] [mlir] [mlir][mesh, mpi] Lower allreduce (PR #144060)
llvmlistbot at llvm.org
Fri Jun 13 05:17:42 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
Add a lowering from mesh.allreduce to mpi.allreduce.
Minor restructuring to increase code reuse.
---
Patch is 63.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144060.diff
12 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+2)
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.h (+1)
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+6-4)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+2-2)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+5-5)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (+4)
- (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+132-39)
- (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+38)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+25)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+16-6)
- (modified) mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (+207-150)
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+30)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..5a864865adffc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
shard/partition sizes depend on the rank.
}];
let dependentDialects = [
+ "affine::AffineDialect",
+ "arith::ArithDialect",
"memref::MemRefDialect",
"mpi::MPIDialect",
"scf::SCFDialect",
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
index f06b911ce3fe3..2b6743cd008c6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.h
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
@@ -12,6 +12,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
// MPIDialect
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d78aa92d201e7..c14837f6961eb 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/MPI/IR/MPI.td"
include "mlir/Dialect/MPI/IR/MPITypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
class MPI_Op<string mnemonic, list<Trait> traits = []>
: Op<MPI_Dialect, mnemonic, traits>;
@@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
// CommWorldOp
//===----------------------------------------------------------------------===//
-def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
let description = [{
This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
// CommRankOp
//===----------------------------------------------------------------------===//
-def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
@@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
);
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// CommSizeOp
//===----------------------------------------------------------------------===//
-def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
@@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
// CommSplitOp
//===----------------------------------------------------------------------===//
-def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f59c4c4c67517..ac05ee243d7be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
```
}];
let arguments = !con(commonArgs, (ins
- AnyRankedTensor:$input,
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
));
let results = (outs
- AnyRankedTensor:$result
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index c64da29ca6412..3f1041cb25103 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
auto isEndomorphismOp = [reduction](Operation *op,
std::optional<Operation *> referenceOp) {
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
- if (!allReduceOp ||
- allReduceOp.getInput().getType().getElementType() !=
- allReduceOp.getResult().getType().getElementType() ||
+ auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
+ auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
+ if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
allReduceOp.getReduction() != reduction) {
return false;
}
@@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
}
auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+ auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
- allReduceOp.getInput().getType().getElementType() ==
- refAllReduceOp.getInput().getType().getElementType();
+ inType.getElementType() == refType.getElementType();
};
auto isAlgebraicOp = [](Operation *op) {
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index be82e2af399dc..5a1154bf9166e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -42,6 +42,10 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
ArrayRef<MeshAxis> meshAxes,
ImplicitLocOpBuilder &builder);
+/// Overload of createProcessLinearIndex that takes an explicit process
+/// multi-index (`processInGroupMultiIndex`) instead of only the mesh axes.
+/// Returns an index-typed value built with `builder`.
+/// NOTE(review): which axes `processInGroupMultiIndex` must cover relative to
+/// `meshAxes` is defined by the implementation in Transforms.cpp (not visible
+/// here) — confirm there before relying on it.
+TypedValue<IndexType>
+createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
+ ArrayRef<MeshAxis> meshAxes,
+ ImplicitLocOpBuilder &builder);
} // namespace mesh
} // namespace mlir
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 823d4d644f586..521569e69b61a 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -12,9 +12,9 @@
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +22,8 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -289,27 +291,15 @@ struct ConvertProcessMultiIndexOp
class ConvertProcessLinearIndexOp
: public OpConversionPattern<ProcessLinearIndexOp> {
- int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
public:
using OpConversionPattern::OpConversionPattern;
- // Constructor accepting worldRank
- ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
- MLIRContext *context, int64_t worldRank = -1)
- : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
-
LogicalResult
matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
+ // Create mpi::CommRankOp
Location loc = op.getLoc();
- if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
- rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
- return success();
- }
-
- // Otherwise call create mpi::CommRankOp
auto ctx = op.getContext();
Value commWorld =
rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
@@ -529,6 +519,124 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
}
};
+/// Maps a mesh reduction kind to the equivalent MPI reduction operation.
+///
+/// Only kinds with a direct MPI counterpart (sum, product, min, max and the
+/// bitwise and/or/xor) are supported; callers must reject any other kind
+/// before calling this helper.
+static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
+  MLIRContext *ctx = kind.getContext();
+  switch (kind.getValue()) {
+  case ReductionKind::Sum:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_SUM);
+  case ReductionKind::Product:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_PROD);
+  case ReductionKind::Min:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MIN);
+  case ReductionKind::Max:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MAX);
+  case ReductionKind::BitwiseAnd:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BAND);
+  case ReductionKind::BitwiseOr:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BOR);
+  case ReductionKind::BitwiseXor:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BXOR);
+  default:
+    // `assert(false)` would compile away under NDEBUG and fall off the end of
+    // this non-void function; llvm_unreachable is the LLVM idiom for this.
+    llvm_unreachable("Unknown/unsupported reduction kind");
+  }
+}
+
+/// Lowers mesh.all_reduce to mpi.allreduce.
+///
+/// The input (tensor or memref) is copied into a freshly allocated buffer,
+/// MPI_COMM_WORLD is split into per-group communicators, and an in-place MPI
+/// allreduce is performed on a 1-d view of the buffer. If the op produces a
+/// tensor, the buffer is converted back to a tensor.
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only reductions with a direct MPI equivalent can be lowered. Reject the
+    // others up front instead of hitting the unreachable path when mapping the
+    // reduction kind to an MPI op.
+    switch (adaptor.getReduction()) {
+    case ReductionKind::Sum:
+    case ReductionKind::Product:
+    case ReductionKind::Min:
+    case ReductionKind::Max:
+    case ReductionKind::BitwiseAnd:
+    case ReductionKind::BitwiseOr:
+    case ReductionKind::BitwiseXor:
+      break;
+    default:
+      return rewriter.notifyMatchFailure(
+          op, "reduction kind has no MPI equivalent");
+    }
+
+    SymbolTableCollection symbolTableCollection;
+    auto mesh = adaptor.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    if (!meshOp)
+      return op->emitError() << "No mesh found for AllReduceOp";
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return op->emitError()
+             << "Dynamic mesh shape not supported in AllReduceOp";
+
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = adaptor.getInput();
+    auto inputShape = cast<ShapedType>(input.getType()).getShape();
+
+    // If the source is a tensor, cast it to a memref.
+    if (isa<RankedTensorType>(input.getType())) {
+      auto memrefType = MemRefType::get(
+          inputShape, cast<ShapedType>(input.getType()).getElementType());
+      input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input);
+    }
+    MemRefType inType = cast<MemRefType>(input.getType());
+
+    // Get the actual shape to allocate the buffer. Dynamic extents are
+    // queried from the input along dimension `i` (not the dynamic sentinel).
+    SmallVector<OpFoldResult> shape(inType.getRank());
+    for (int64_t i = 0, e = inType.getRank(); i < e; ++i) {
+      int64_t s = inputShape[i];
+      if (ShapedType::isDynamic(s))
+        shape[i] = iBuilder.create<memref::DimOp>(input, i).getResult();
+      else
+        shape[i] = iBuilder.getIndexAttr(s);
+    }
+
+    // Allocate buffer and copy input to buffer; the allreduce below works in
+    // place on this copy so the input itself is never written.
+    Value buffer = iBuilder.create<memref::AllocOp>(
+        shape, cast<ShapedType>(op.getType()).getElementType());
+    iBuilder.create<linalg::CopyOp>(input, buffer);
+
+    // Get an MPI_Comm_split for the AllReduce operation.
+    // The color is the linear index of the process in the mesh along the
+    // non-reduced axes. The key is the linear index of the process in the mesh
+    // along the reduced axes.
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       iBuilder.getIndexType());
+    SmallVector<Value> myMultiIndex =
+        iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
+            .getResult();
+    Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
+    SmallVector<Value> multiKey(myMultiIndex.size(), zero);
+
+    auto redAxes = adaptor.getMeshAxes();
+    for (auto axis : redAxes) {
+      multiKey[axis] = myMultiIndex[axis];
+      myMultiIndex[axis] = zero;
+    }
+
+    // NOTE(review): color and key are both linearized with `redAxes` while
+    // the multi-indices span the full mesh rank; confirm this matches the
+    // contract of createProcessLinearIndex.
+    Value color =
+        createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+    color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
+    Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+    key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
+
+    // Finally split the communicator
+    auto commType = mpi::CommType::get(op->getContext());
+    Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
+    auto comm =
+        iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
+            .getNewcomm();
+
+    Value buffer1d = buffer;
+    // Collapse shape to 1d if needed
+    if (inType.getRank() > 1) {
+      ReassociationIndices reassociation(inType.getRank());
+      std::iota(reassociation.begin(), reassociation.end(), 0);
+      buffer1d = iBuilder.create<memref::CollapseShapeOp>(
+          buffer, ArrayRef<ReassociationIndices>(reassociation));
+    }
+
+    // Create the MPI AllReduce operation (send and receive buffer coincide,
+    // i.e. the reduction is done in place).
+    iBuilder.create<mpi::AllReduceOp>(
+        TypeRange(), buffer1d, buffer1d,
+        getMPIReduction(adaptor.getReductionAttr()), comm);
+
+    // If the destination is a tensor, cast the buffer back to a tensor.
+    if (isa<RankedTensorType>(op.getType()))
+      buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer,
+                                                          true);
+
+    rewriter.replaceOp(op, buffer);
+    return success();
+  }
+};
+
struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
using OpConversionPattern::OpConversionPattern;
@@ -573,10 +681,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
Value array = dest;
if (isa<RankedTensorType>(array.getType())) {
// If the destination is a memref, we need to cast it to a tensor
- auto tensorType = MemRefType::get(
+ auto mmemrefType = MemRefType::get(
dstShape, cast<ShapedType>(array.getType()).getElementType());
array =
- rewriter.create<bufferization::ToBufferOp>(loc, tensorType, array);
+ rewriter.create<bufferization::ToMemrefOp>(loc, mmemrefType, array);
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -753,22 +861,6 @@ struct ConvertMeshToMPIPass
/// Run the dialect converter on the module.
void runOnOperation() override {
- uint64_t worldRank = -1;
- // Try to get DLTI attribute for MPI:comm_world_rank
- // If found, set worldRank to the value of the attribute.
- {
- auto dltiAttr =
- dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
- if (succeeded(dltiAttr)) {
- if (!isa<IntegerAttr>(dltiAttr.value())) {
- getOperation()->emitError()
- << "Expected an integer attribute for MPI:comm_world_rank";
- return signalPassFailure();
- }
- worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
- }
- }
-
auto *ctxt = &getContext();
RewritePatternSet patterns(ctxt);
ConversionTarget target(getContext());
@@ -819,10 +911,10 @@ struct ConvertMeshToMPIPass
// ...except the global MeshOp
target.addLegalOp<mesh::MeshOp>();
// Allow all the stuff that our patterns will convert to
- target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
- arith::ArithDialect, tensor::TensorDialect,
- bufferization::BufferizationDialect,
- linalg::LinalgDialect, memref::MemRefDialect>();
+ target.addLegalDialect<
+ BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
+ tensor::TensorDialect, bufferization::BufferizationDialect,
+ linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
// Make sure the function signature, calls etc. are legal
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType());
@@ -832,9 +924,10 @@ struct ConvertMeshToMPIPass
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertProcessMultiIndexOp, ConvertGetShardingOp,
- ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
- // ConvertProcessLinearIndexOp accepts an optional worldRank
- patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+ ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
+ ConvertProcessLinearIndexOp>(typeConverter, ctxt);
+ SymbolTableCollection symbolTableCollection;
+ mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 56d8edfbcc025..6d445ca0e4099 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
@@ -41,6 +42,38 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
return mlir::success();
}
};
+
+/// Folds `mpi.comm_rank(mpi.comm_world)` to a constant when a DLTI entry
+/// `MPI:comm_world_rank` holding an integer attribute is visible from the op.
+struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
+  using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
+                                mlir::PatternRewriter &b) const override {
+    // Only the rank in MPI_COMM_WORLD can be provided statically.
+    if (!op.getComm().getDefiningOp<mlir::mpi::CommWorldOp>())
+      return b.notifyMatchFailure(op, "communicator is not comm_world");
+
+    // The op is erased below, so its optional error code must be unused;
+    // replacing it with itself would leave dangling uses.
+    if (Value retVal = op.getRetval(); retVal && !retVal.use_empty())
+      return b.notifyMatchFailure(op, "retval of comm_rank is in use");
+
+    auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
+    if (failed(dltiAttr))
+      return b.notifyMatchFailure(op, "no MPI:comm_world_rank DLTI entry");
+    // Canonicalization patterns must not emit diagnostics; a malformed entry
+    // simply prevents the fold.
+    auto rankAttr = dyn_cast<IntegerAttr>(dltiAttr.value());
+    if (!rankAttr)
+      return b.notifyMatchFailure(
+          op, "Expected an integer attribute for MPI:comm_world_rank");
+
+    // Materialize the constant with the rank result's own type so the
+    // replacement is type-correct.
+    // NOTE(review): this creates arith ops; confirm the arith dialect is
+    // loaded wherever MPI canonicalization runs.
+    Value rank = op.getRank();
+    Value res = b.create<arith::ConstantOp>(
+        op.getLoc(), b.getIntegerAttr(rank.getType(), rankAttr.getInt()));
+    b.replaceAllUsesWith(rank, res);
+    b.eraseOp(op);
+    return mlir::success();
+  }
+};
+
} // namespace
void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -63,6 +96,11 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
}
+/// Registers the canonicalization patterns for mpi.comm_rank: currently only
+/// the DLTI-driven rank fold (FoldRank).
+void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx) {
+  patterns.add<FoldRank>(ctx);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304cb55a35086..b84de2b716b32 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -75,6 +75,31 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
return lhs.value() * rhs.value();
}
+/// Converts a vector of OpFoldResults (ints) into vector of Values of the
+/// provided type.
+SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
+ const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics,
+ Type ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/144060
More information about the Mlir-commits mailing list