[Mlir-commits] [mlir] [MLIR][OpenMP] Simplify OpenMP to LLVM dialect conversion (PR #132009)
Sergio Afonso
llvmlistbot at llvm.org
Thu Mar 20 04:13:01 PDT 2025
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/132009
>From 4c5df0326096578367962c1cb92ac32a521d336a Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 19 Mar 2025 11:29:09 +0000
Subject: [PATCH 1/2] [MLIR][OpenMP] Simplify OpenMP to LLVM dialect conversion
This patch makes a few changes to unify the conversion process from the 'omp'
dialect to the 'llvm' dialect. Its main goals are to consolidate the logic used
to identify legal and illegal ops into a single generic check, and to move all
of the conversion logic into a single pattern class.
Changes introduced are the following:
- Removal of `getNumVariableOperands()` and `getVariableOperand()` extra
class declarations from OpenMP operations. These are redundant, as they are
equivalent to `mlir::Operation::getNumOperands()` and
`mlir::Operation::getOperand()`, respectively.
- Consolidation of `RegionOpConversion`,
`RegionLessOpWithVarOperandsConversion`, `RegionOpWithVarOperandsConversion`,
`RegionLessOpConversion`, `AtomicReadOpConversion`, `MapInfoOpConversion`,
`DeclMapperOpConversion` and `MultiRegionOpConversion` into a single
`OpenMPOpConversion` class. This is possible because all of these patterns
performed subsets of the same sequence of steps, differing only in whether the
op defined any regions, whether it took operands, whether it had type
attributes, etc.
- Update of `mlir::configureOpenMPToLLVMConversionLegality` to use a single
generic set of checks for all operations (operand, result, region and
type-attribute types), removing the need to list every operation manually.
---
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 92 -----
.../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 345 +++++-------------
2 files changed, 95 insertions(+), 342 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 2c2ecdd225f4a..63b38835d133f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -891,17 +891,6 @@ def FlushOp : OpenMP_Op<"flush", clauses = [
// Override inherited assembly format to include `varList`.
let assemblyFormat = "( `(` $varList^ `:` type($varList) `)` )? attr-dict";
-
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- return getOperation()->getNumOperands();
- }
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- return getOperand(i);
- }
- }] # clausesExtraClassDeclaration;
}
//===----------------------------------------------------------------------===//
@@ -1001,18 +990,6 @@ def MapBoundsOp : OpenMP_Op<"map.bounds",
) attr-dict
}];
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- return getNumOperands();
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- return getOperands()[i];
- }
- }];
-
let hasVerifier = 1;
}
@@ -1098,18 +1075,6 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
| `bounds` `(` $bounds `)`
) `->` type($omp_ptr) attr-dict
}];
-
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- return getNumOperands();
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- return getOperands()[i];
- }
- }];
}
//===---------------------------------------------------------------------===//
@@ -1515,21 +1480,6 @@ def AtomicReadOp : OpenMP_Op<"atomic.read", traits = [
clausesOptAssemblyFormat #
") `:` type($v) `,` type($x) `,` $element_type attr-dict";
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getX() && "expected 'x' operand");
- assert(getV() && "expected 'v' operand");
- return 2;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i < 2 && "invalid index position for an operand");
- return i == 0 ? getX() : getV();
- }
- }] # clausesExtraClassDeclaration;
-
let hasVerifier = 1;
}
@@ -1555,21 +1505,6 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write", traits = [
let assemblyFormat = "$x `=` $expr" # clausesReqAssemblyFormat # " oilist(" #
clausesOptAssemblyFormat # ") `:` type($x) `,` type($expr) attr-dict";
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getX() && "expected address operand");
- assert(getExpr() && "expected value operand");
- return 2;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i < 2 && "invalid index position for an operand");
- return i == 0 ? getX() : getExpr();
- }
- }] # clausesExtraClassDeclaration;
-
let hasVerifier = 1;
}
@@ -1614,20 +1549,6 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update", traits = [
let assemblyFormat = clausesAssemblyFormat #
"$x `:` type($x) $region attr-dict";
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getX() && "expected 'x' operand");
- return 1;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i == 0 && "invalid index position for an operand");
- return getX();
- }
- }] # clausesExtraClassDeclaration;
-
let hasVerifier = 1;
let hasRegionVerifier = 1;
let hasCanonicalizeMethod = 1;
@@ -1715,19 +1636,6 @@ def ThreadprivateOp : OpenMP_Op<"threadprivate",
let assemblyFormat = [{
$sym_addr `:` type($sym_addr) `->` type($tls_addr) attr-dict
}];
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getSymAddr() && "expected one variable operand");
- return 1;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i == 0 && "invalid index position for an operand");
- return getSymAddr();
- }
- }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 7888745dc6920..218260bd53a1e 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -28,262 +28,101 @@ namespace mlir {
using namespace mlir;
namespace {
-/// A pattern that converts the region arguments in a single-region OpenMP
-/// operation to the LLVM dialect. The body of the region is not modified and is
-/// expected to either be processed by the conversion infrastructure or already
-/// contain ops compatible with LLVM dialect types.
-template <typename OpType>
-struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
- using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.create<OpType>(
- curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
- rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
- *this->getTypeConverter())))
- return failure();
-
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-template <typename T>
-struct RegionLessOpWithVarOperandsConversion
- : public ConvertOpToLLVMPattern<T> {
+/// A pattern that converts the result and operand types, attributes, and region
+/// arguments of an OpenMP operation to the LLVM dialect.
+///
+/// Attributes are copied verbatim by default, and only translated if they are
+/// type attributes.
+///
+/// Region bodies, if any, are not modified and expected to either be processed
+/// by the conversion infrastructure or already contain ops compatible with LLVM
+/// dialect types.
+template <typename T, bool SupportsMemRefOperand = true>
+struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(T curOp, typename T::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- SmallVector<Type> resTypes;
- if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
- return failure();
- SmallVector<Value> convertedOperands;
- assert(curOp.getNumVariableOperands() ==
- curOp.getOperation()->getNumOperands() &&
- "unexpected non-variable operands");
- for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
- Value originalVariableOperand = curOp.getVariableOperand(idx);
- if (!originalVariableOperand)
- return failure();
- if (isa<MemRefType>(originalVariableOperand.getType())) {
- // TODO: Support memref type in variable operands
- return rewriter.notifyMatchFailure(curOp,
- "memref is not supported yet");
- }
- convertedOperands.emplace_back(adaptor.getOperands()[idx]);
- }
- rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
- curOp->getAttrs());
- return success();
- }
-};
-
-template <typename T>
-struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
- using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(T curOp, typename T::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- SmallVector<Type> resTypes;
- if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
- return failure();
- SmallVector<Value> convertedOperands;
- assert(curOp.getNumVariableOperands() ==
- curOp.getOperation()->getNumOperands() &&
- "unexpected non-variable operands");
- for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
- Value originalVariableOperand = curOp.getVariableOperand(idx);
- if (!originalVariableOperand)
- return failure();
- if (isa<MemRefType>(originalVariableOperand.getType())) {
- // TODO: Support memref type in variable operands
- return rewriter.notifyMatchFailure(curOp,
- "memref is not supported yet");
- }
- convertedOperands.emplace_back(adaptor.getOperands()[idx]);
- }
- auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
- curOp->getAttrs());
- rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
- *this->getTypeConverter())))
- return failure();
-
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-template <typename T>
-struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
- using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(T curOp, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Translate result types.
const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
SmallVector<Type> resTypes;
if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
return failure();
- rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
- curOp->getAttrs());
- return success();
- }
-};
-
-struct AtomicReadOpConversion
- : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
- using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- Type curElementType = curOp.getElementType();
- auto newOp = rewriter.create<omp::AtomicReadOp>(
- curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
- TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
- newOp.setElementTypeAttr(typeAttr);
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
- using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
-
- SmallVector<Type> resTypes;
- if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
- return failure();
-
- // Copy attributes of the curOp except for the typeAttr which should
- // be converted
- SmallVector<NamedAttribute> newAttrs;
+ // Translate type attributes.
+ // They are kept unmodified except if they are type attributes.
+ SmallVector<NamedAttribute> convertedAttrs;
for (NamedAttribute attr : curOp->getAttrs()) {
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
- Type newAttr = converter->convertType(typeAttr.getValue());
- newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+ Type convertedType = converter->convertType(typeAttr.getValue());
+ convertedAttrs.emplace_back(attr.getName(),
+ TypeAttr::get(convertedType));
} else {
- newAttrs.push_back(attr);
+ convertedAttrs.push_back(attr);
}
}
- rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
- curOp, resTypes, adaptor.getOperands(), newAttrs);
- return success();
- }
-};
-
-struct DeclMapperOpConversion
- : public ConvertOpToLLVMPattern<omp::DeclareMapperOp> {
- using ConvertOpToLLVMPattern<omp::DeclareMapperOp>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(omp::DeclareMapperOp curOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- SmallVector<NamedAttribute> newAttrs;
- newAttrs.emplace_back(curOp.getSymNameAttrName(), curOp.getSymNameAttr());
- newAttrs.emplace_back(
- curOp.getTypeAttrName(),
- TypeAttr::get(converter->convertType(curOp.getType())));
-
- auto newOp = rewriter.create<omp::DeclareMapperOp>(
- curOp.getLoc(), TypeRange(), adaptor.getOperands(), newAttrs);
- rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
- *this->getTypeConverter())))
- return failure();
-
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-template <typename OpType>
-struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
- using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
-
- void forwardOpAttrs(OpType curOp, OpType newOp) const {}
+ // Translate operands.
+ SmallVector<Value> convertedOperands;
+ convertedOperands.reserve(curOp->getNumOperands());
+ for (auto [originalOperand, convertedOperand] :
+ llvm::zip_equal(curOp->getOperands(), adaptor.getOperands())) {
+ if (!originalOperand)
+ return failure();
- LogicalResult
- matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.create<OpType>(
- curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
- TypeAttr::get(this->getTypeConverter()->convertType(
- curOp.getTypeAttr().getValue())));
- forwardOpAttrs(curOp, newOp);
+ if constexpr (!SupportsMemRefOperand) {
+ if (isa<MemRefType>(originalOperand.getType())) {
+ // TODO: Support memref type in variable operands
+ return rewriter.notifyMatchFailure(curOp,
+ "memref is not supported yet");
+ }
+ }
+ convertedOperands.push_back(convertedOperand);
+ }
- for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
- rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
- newOp.getRegion(idx).end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
+ // Create new operation.
+ auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
+ convertedAttrs);
+
+ // Translate regions.
+ for (auto [originalRegion, convertedRegion] :
+ llvm::zip_equal(curOp->getRegions(), newOp->getRegions())) {
+ rewriter.inlineRegionBefore(originalRegion, convertedRegion,
+ convertedRegion.end());
+ if (failed(rewriter.convertRegionTypes(&convertedRegion,
*this->getTypeConverter())))
return failure();
}
- rewriter.eraseOp(curOp);
+ // Delete old operation and replace result uses with those of the new one.
+ rewriter.replaceOp(curOp, newOp->getResults());
return success();
}
};
-template <>
-void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
- omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
- newOp.setDataSharingType(curOp.getDataSharingType());
-}
} // namespace
void mlir::configureOpenMPToLLVMConversionLegality(
ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
target.addDynamicallyLegalOp<
- omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
- omp::CancelOp, omp::CriticalDeclareOp, omp::DeclareMapperInfoOp,
- omp::FlushOp, omp::MapBoundsOp, omp::MapInfoOp, omp::OrderedOp,
- omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
- omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>(
- [&](Operation *op) {
- return typeConverter.isLegal(op->getOperandTypes()) &&
- typeConverter.isLegal(op->getResultTypes());
- });
- target.addDynamicallyLegalOp<
- omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareMapperOp,
- omp::DeclareReductionOp, omp::DistributeOp, omp::LoopNestOp, omp::LoopOp,
- omp::MasterOp, omp::OrderedRegionOp, omp::ParallelOp,
- omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp, omp::SimdOp,
- omp::SingleOp, omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp,
- omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
- omp::WsloopOp>([&](Operation *op) {
- return std::all_of(op->getRegions().begin(), op->getRegions().end(),
+#define GET_OP_LIST
+#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
+ >([&](Operation *op) {
+ return typeConverter.isLegal(op->getOperandTypes()) &&
+ typeConverter.isLegal(op->getResultTypes()) &&
+ std::all_of(op->getRegions().begin(), op->getRegions().end(),
[&](Region &region) {
return typeConverter.isLegal(&region);
}) &&
- typeConverter.isLegal(op->getOperandTypes()) &&
- typeConverter.isLegal(op->getResultTypes());
+ std::all_of(op->getAttrs().begin(), op->getAttrs().end(),
+ [&](NamedAttribute attr) {
+ auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
+ return !typeAttr ||
+ typeConverter.isLegal(typeAttr.getValue());
+ });
});
- target.addDynamicallyLegalOp<omp::PrivateClauseOp>(
- [&](omp::PrivateClauseOp op) -> bool {
- return std::all_of(op->getRegions().begin(), op->getRegions().end(),
- [&](Region &region) {
- return typeConverter.isLegal(&region);
- }) &&
- typeConverter.isLegal(op->getOperandTypes()) &&
- typeConverter.isLegal(op->getResultTypes()) &&
- typeConverter.isLegal(op.getType());
- });
}
void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -295,36 +134,42 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
[&](omp::MapBoundsType type) -> Type { return type; });
patterns.add<
- AtomicReadOpConversion, DeclMapperOpConversion, MapInfoOpConversion,
- MultiRegionOpConversion<omp::DeclareReductionOp>,
- MultiRegionOpConversion<omp::PrivateClauseOp>,
- RegionLessOpConversion<omp::CancellationPointOp>,
- RegionLessOpConversion<omp::CancelOp>,
- RegionLessOpConversion<omp::CriticalDeclareOp>,
- RegionLessOpConversion<omp::DeclareMapperInfoOp>,
- RegionLessOpConversion<omp::OrderedOp>,
- RegionLessOpConversion<omp::ScanOp>,
- RegionLessOpConversion<omp::TargetEnterDataOp>,
- RegionLessOpConversion<omp::TargetExitDataOp>,
- RegionLessOpConversion<omp::TargetUpdateOp>,
- RegionLessOpConversion<omp::YieldOp>,
- RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
- RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
- RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
- RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
- RegionOpConversion<omp::AtomicCaptureOp>,
- RegionOpConversion<omp::CriticalOp>,
- RegionOpConversion<omp::DistributeOp>,
- RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
- RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
- RegionOpConversion<omp::OrderedRegionOp>,
- RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
- RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
- RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>,
- RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>,
- RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>,
- RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>,
- RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
+ OpenMPOpConversion<omp::AtomicCaptureOp>,
+ OpenMPOpConversion<omp::AtomicReadOp>,
+ OpenMPOpConversion<omp::AtomicUpdateOp, /*SupportsMemRefOperand=*/false>,
+ OpenMPOpConversion<omp::AtomicWriteOp, /*SupportsMemRefOperand=*/false>,
+ OpenMPOpConversion<omp::BarrierOp>,
+ OpenMPOpConversion<omp::CancellationPointOp>,
+ OpenMPOpConversion<omp::CancelOp>,
+ OpenMPOpConversion<omp::CriticalDeclareOp>,
+ OpenMPOpConversion<omp::CriticalOp>,
+ OpenMPOpConversion<omp::DeclareMapperInfoOp>,
+ OpenMPOpConversion<omp::DeclareMapperOp>,
+ OpenMPOpConversion<omp::DeclareReductionOp>,
+ OpenMPOpConversion<omp::DistributeOp>,
+ OpenMPOpConversion<omp::FlushOp, /*SupportsMemRefOperand=*/false>,
+ OpenMPOpConversion<omp::LoopNestOp>, OpenMPOpConversion<omp::LoopOp>,
+ OpenMPOpConversion<omp::MapBoundsOp, /*SupportsMemRefOperand=*/false>,
+ OpenMPOpConversion<omp::MapInfoOp>, OpenMPOpConversion<omp::MaskedOp>,
+ OpenMPOpConversion<omp::MasterOp>, OpenMPOpConversion<omp::OrderedOp>,
+ OpenMPOpConversion<omp::OrderedRegionOp>,
+ OpenMPOpConversion<omp::ParallelOp>,
+ OpenMPOpConversion<omp::PrivateClauseOp>, OpenMPOpConversion<omp::ScanOp>,
+ OpenMPOpConversion<omp::SectionOp>, OpenMPOpConversion<omp::SectionsOp>,
+ OpenMPOpConversion<omp::SimdOp>, OpenMPOpConversion<omp::SingleOp>,
+ OpenMPOpConversion<omp::TargetDataOp>,
+ OpenMPOpConversion<omp::TargetEnterDataOp>,
+ OpenMPOpConversion<omp::TargetExitDataOp>,
+ OpenMPOpConversion<omp::TargetOp>,
+ OpenMPOpConversion<omp::TargetUpdateOp>,
+ OpenMPOpConversion<omp::TaskgroupOp>, OpenMPOpConversion<omp::TaskloopOp>,
+ OpenMPOpConversion<omp::TaskOp>, OpenMPOpConversion<omp::TaskwaitOp>,
+ OpenMPOpConversion<omp::TaskyieldOp>, OpenMPOpConversion<omp::TeamsOp>,
+ OpenMPOpConversion<omp::TerminatorOp>,
+ OpenMPOpConversion<omp::ThreadprivateOp, /*SupportsMemRefOperand=*/false>,
+ OpenMPOpConversion<omp::WorkshareLoopWrapperOp>,
+ OpenMPOpConversion<omp::WorkshareOp>, OpenMPOpConversion<omp::WsloopOp>,
+ OpenMPOpConversion<omp::YieldOp>>(converter);
}
namespace {
>From 2bdd87af8632ac1ba8ccc4e5719d3436306b5b26 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 20 Mar 2025 11:12:46 +0000
Subject: [PATCH 2/2] Remove the need to list operations manually
---
.../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 78 +++++++------------
1 file changed, 29 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 218260bd53a1e..6119097456d1f 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -38,23 +38,23 @@ namespace {
/// Region bodies, if any, are not modified and expected to either be processed
/// by the conversion infrastructure or already contain ops compatible with LLVM
/// dialect types.
-template <typename T, bool SupportsMemRefOperand = true>
+template <typename T>
struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(T curOp, typename T::Adaptor adaptor,
+ matchAndRewrite(T op, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Translate result types.
const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
SmallVector<Type> resTypes;
- if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
+ if (failed(converter->convertTypes(op->getResultTypes(), resTypes)))
return failure();
// Translate type attributes.
// They are kept unmodified except if they are type attributes.
SmallVector<NamedAttribute> convertedAttrs;
- for (NamedAttribute attr : curOp->getAttrs()) {
+ for (NamedAttribute attr : op->getAttrs()) {
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
Type convertedType = converter->convertType(typeAttr.getValue());
convertedAttrs.emplace_back(attr.getName(),
@@ -66,29 +66,32 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
// Translate operands.
SmallVector<Value> convertedOperands;
- convertedOperands.reserve(curOp->getNumOperands());
+ convertedOperands.reserve(op->getNumOperands());
for (auto [originalOperand, convertedOperand] :
- llvm::zip_equal(curOp->getOperands(), adaptor.getOperands())) {
+ llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
if (!originalOperand)
return failure();
- if constexpr (!SupportsMemRefOperand) {
+ // TODO: Revisit whether we need to trigger an error specifically for this
+ // set of operations. Consider removing this check or updating the list.
+ if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
+ omp::FlushOp, omp::MapBoundsOp,
+ omp::ThreadprivateOp>::value) {
if (isa<MemRefType>(originalOperand.getType())) {
// TODO: Support memref type in variable operands
- return rewriter.notifyMatchFailure(curOp,
- "memref is not supported yet");
+ return rewriter.notifyMatchFailure(op, "memref is not supported yet");
}
}
convertedOperands.push_back(convertedOperand);
}
// Create new operation.
- auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
+ auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands,
convertedAttrs);
// Translate regions.
for (auto [originalRegion, convertedRegion] :
- llvm::zip_equal(curOp->getRegions(), newOp->getRegions())) {
+ llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
rewriter.inlineRegionBefore(originalRegion, convertedRegion,
convertedRegion.end());
if (failed(rewriter.convertRegionTypes(&convertedRegion,
@@ -97,7 +100,7 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
}
// Delete old operation and replace result uses with those of the new one.
- rewriter.replaceOp(curOp, newOp->getResults());
+ rewriter.replaceOp(op, newOp->getResults());
return success();
}
};
@@ -125,6 +128,15 @@ void mlir::configureOpenMPToLLVMConversionLegality(
});
}
+/// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
+/// passed as template argument.
+template <typename... Ts>
+static inline RewritePatternSet &
+addOpenMPOpConversions(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ return patterns.add<OpenMPOpConversion<Ts>...>(converter);
+}
+
void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// This type is allowed when converting OpenMP to LLVM Dialect, it carries
@@ -133,43 +145,11 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
converter.addConversion(
[&](omp::MapBoundsType type) -> Type { return type; });
- patterns.add<
- OpenMPOpConversion<omp::AtomicCaptureOp>,
- OpenMPOpConversion<omp::AtomicReadOp>,
- OpenMPOpConversion<omp::AtomicUpdateOp, /*SupportsMemRefOperand=*/false>,
- OpenMPOpConversion<omp::AtomicWriteOp, /*SupportsMemRefOperand=*/false>,
- OpenMPOpConversion<omp::BarrierOp>,
- OpenMPOpConversion<omp::CancellationPointOp>,
- OpenMPOpConversion<omp::CancelOp>,
- OpenMPOpConversion<omp::CriticalDeclareOp>,
- OpenMPOpConversion<omp::CriticalOp>,
- OpenMPOpConversion<omp::DeclareMapperInfoOp>,
- OpenMPOpConversion<omp::DeclareMapperOp>,
- OpenMPOpConversion<omp::DeclareReductionOp>,
- OpenMPOpConversion<omp::DistributeOp>,
- OpenMPOpConversion<omp::FlushOp, /*SupportsMemRefOperand=*/false>,
- OpenMPOpConversion<omp::LoopNestOp>, OpenMPOpConversion<omp::LoopOp>,
- OpenMPOpConversion<omp::MapBoundsOp, /*SupportsMemRefOperand=*/false>,
- OpenMPOpConversion<omp::MapInfoOp>, OpenMPOpConversion<omp::MaskedOp>,
- OpenMPOpConversion<omp::MasterOp>, OpenMPOpConversion<omp::OrderedOp>,
- OpenMPOpConversion<omp::OrderedRegionOp>,
- OpenMPOpConversion<omp::ParallelOp>,
- OpenMPOpConversion<omp::PrivateClauseOp>, OpenMPOpConversion<omp::ScanOp>,
- OpenMPOpConversion<omp::SectionOp>, OpenMPOpConversion<omp::SectionsOp>,
- OpenMPOpConversion<omp::SimdOp>, OpenMPOpConversion<omp::SingleOp>,
- OpenMPOpConversion<omp::TargetDataOp>,
- OpenMPOpConversion<omp::TargetEnterDataOp>,
- OpenMPOpConversion<omp::TargetExitDataOp>,
- OpenMPOpConversion<omp::TargetOp>,
- OpenMPOpConversion<omp::TargetUpdateOp>,
- OpenMPOpConversion<omp::TaskgroupOp>, OpenMPOpConversion<omp::TaskloopOp>,
- OpenMPOpConversion<omp::TaskOp>, OpenMPOpConversion<omp::TaskwaitOp>,
- OpenMPOpConversion<omp::TaskyieldOp>, OpenMPOpConversion<omp::TeamsOp>,
- OpenMPOpConversion<omp::TerminatorOp>,
- OpenMPOpConversion<omp::ThreadprivateOp, /*SupportsMemRefOperand=*/false>,
- OpenMPOpConversion<omp::WorkshareLoopWrapperOp>,
- OpenMPOpConversion<omp::WorkshareOp>, OpenMPOpConversion<omp::WsloopOp>,
- OpenMPOpConversion<omp::YieldOp>>(converter);
+ // Add conversions for all OpenMP operations.
+ addOpenMPOpConversions<
+#define GET_OP_LIST
+#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
+ >(converter, patterns);
}
namespace {
More information about the Mlir-commits
mailing list