[Mlir-commits] [mlir] [MLIR][OpenMP] Simplify OpenMP to LLVM dialect conversion (PR #132009)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 19 04:48:20 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
This patch makes a few changes to unify the conversion process from the 'omp' dialect to the 'llvm' dialect. The main goals are to consolidate the logic used to identify legal and illegal ops, and to merge the conversion logic into a single pattern class.
Changes introduced are the following:
- Removal of the `getNumVariableOperands()` and `getVariableOperand()` extra class declarations from OpenMP operations. These are redundant, as they are equivalent to `mlir::Operation::getNumOperands()` and `mlir::Operation::getOperand()` (i.e. indexing into `mlir::Operation::getOperands()`), respectively.
- Consolidation of `RegionOpConversion`, `RegionLessOpWithVarOperandsConversion`, `RegionOpWithVarOperandsConversion`, `RegionLessOpConversion`, `AtomicReadOpConversion`, `MapInfoOpConversion`, `DeclMapperOpConversion` and `MultiRegionOpConversion` into a single `OpenMPOpConversion` class. This is possible because each of these patterns performed a subset of the same conversion steps. They differed only in whether the operation defined any regions, took operands, carried type attributes, etc.
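- Update of `mlir::configureOpenMPToLLVMConversionLegality` to use a single generic set of checks for all operations. The checks cover operand, result, region and type-attribute types. The list of operations is now generated from `OpenMPOps.cpp.inc` via `GET_OP_LIST`, which removes the need to list every operation manually.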
---
Patch is 23.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132009.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (-92)
- (modified) mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp (+95-250)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 2c2ecdd225f4a..63b38835d133f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -891,17 +891,6 @@ def FlushOp : OpenMP_Op<"flush", clauses = [
// Override inherited assembly format to include `varList`.
let assemblyFormat = "( `(` $varList^ `:` type($varList) `)` )? attr-dict";
-
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- return getOperation()->getNumOperands();
- }
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- return getOperand(i);
- }
- }] # clausesExtraClassDeclaration;
}
//===----------------------------------------------------------------------===//
@@ -1001,18 +990,6 @@ def MapBoundsOp : OpenMP_Op<"map.bounds",
) attr-dict
}];
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- return getNumOperands();
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- return getOperands()[i];
- }
- }];
-
let hasVerifier = 1;
}
@@ -1098,18 +1075,6 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
| `bounds` `(` $bounds `)`
) `->` type($omp_ptr) attr-dict
}];
-
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- return getNumOperands();
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- return getOperands()[i];
- }
- }];
}
//===---------------------------------------------------------------------===//
@@ -1515,21 +1480,6 @@ def AtomicReadOp : OpenMP_Op<"atomic.read", traits = [
clausesOptAssemblyFormat #
") `:` type($v) `,` type($x) `,` $element_type attr-dict";
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getX() && "expected 'x' operand");
- assert(getV() && "expected 'v' operand");
- return 2;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i < 2 && "invalid index position for an operand");
- return i == 0 ? getX() : getV();
- }
- }] # clausesExtraClassDeclaration;
-
let hasVerifier = 1;
}
@@ -1555,21 +1505,6 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write", traits = [
let assemblyFormat = "$x `=` $expr" # clausesReqAssemblyFormat # " oilist(" #
clausesOptAssemblyFormat # ") `:` type($x) `,` type($expr) attr-dict";
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getX() && "expected address operand");
- assert(getExpr() && "expected value operand");
- return 2;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i < 2 && "invalid index position for an operand");
- return i == 0 ? getX() : getExpr();
- }
- }] # clausesExtraClassDeclaration;
-
let hasVerifier = 1;
}
@@ -1614,20 +1549,6 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update", traits = [
let assemblyFormat = clausesAssemblyFormat #
"$x `:` type($x) $region attr-dict";
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getX() && "expected 'x' operand");
- return 1;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i == 0 && "invalid index position for an operand");
- return getX();
- }
- }] # clausesExtraClassDeclaration;
-
let hasVerifier = 1;
let hasRegionVerifier = 1;
let hasCanonicalizeMethod = 1;
@@ -1715,19 +1636,6 @@ def ThreadprivateOp : OpenMP_Op<"threadprivate",
let assemblyFormat = [{
$sym_addr `:` type($sym_addr) `->` type($tls_addr) attr-dict
}];
- let extraClassDeclaration = [{
- /// The number of variable operands.
- unsigned getNumVariableOperands() {
- assert(getSymAddr() && "expected one variable operand");
- return 1;
- }
-
- /// The i-th variable operand passed.
- Value getVariableOperand(unsigned i) {
- assert(i == 0 && "invalid index position for an operand");
- return getSymAddr();
- }
- }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 7888745dc6920..218260bd53a1e 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -28,262 +28,101 @@ namespace mlir {
using namespace mlir;
namespace {
-/// A pattern that converts the region arguments in a single-region OpenMP
-/// operation to the LLVM dialect. The body of the region is not modified and is
-/// expected to either be processed by the conversion infrastructure or already
-/// contain ops compatible with LLVM dialect types.
-template <typename OpType>
-struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
- using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.create<OpType>(
- curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
- rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
- *this->getTypeConverter())))
- return failure();
-
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-template <typename T>
-struct RegionLessOpWithVarOperandsConversion
- : public ConvertOpToLLVMPattern<T> {
+/// A pattern that converts the result and operand types, attributes, and region
+/// arguments of an OpenMP operation to the LLVM dialect.
+///
+/// Attributes are copied verbatim by default, and only translated if they are
+/// type attributes.
+///
+/// Region bodies, if any, are not modified and expected to either be processed
+/// by the conversion infrastructure or already contain ops compatible with LLVM
+/// dialect types.
+template <typename T, bool SupportsMemRefOperand = true>
+struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(T curOp, typename T::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- SmallVector<Type> resTypes;
- if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
- return failure();
- SmallVector<Value> convertedOperands;
- assert(curOp.getNumVariableOperands() ==
- curOp.getOperation()->getNumOperands() &&
- "unexpected non-variable operands");
- for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
- Value originalVariableOperand = curOp.getVariableOperand(idx);
- if (!originalVariableOperand)
- return failure();
- if (isa<MemRefType>(originalVariableOperand.getType())) {
- // TODO: Support memref type in variable operands
- return rewriter.notifyMatchFailure(curOp,
- "memref is not supported yet");
- }
- convertedOperands.emplace_back(adaptor.getOperands()[idx]);
- }
- rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
- curOp->getAttrs());
- return success();
- }
-};
-
-template <typename T>
-struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
- using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(T curOp, typename T::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- SmallVector<Type> resTypes;
- if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
- return failure();
- SmallVector<Value> convertedOperands;
- assert(curOp.getNumVariableOperands() ==
- curOp.getOperation()->getNumOperands() &&
- "unexpected non-variable operands");
- for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
- Value originalVariableOperand = curOp.getVariableOperand(idx);
- if (!originalVariableOperand)
- return failure();
- if (isa<MemRefType>(originalVariableOperand.getType())) {
- // TODO: Support memref type in variable operands
- return rewriter.notifyMatchFailure(curOp,
- "memref is not supported yet");
- }
- convertedOperands.emplace_back(adaptor.getOperands()[idx]);
- }
- auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
- curOp->getAttrs());
- rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
- *this->getTypeConverter())))
- return failure();
-
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-template <typename T>
-struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
- using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(T curOp, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Translate result types.
const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
SmallVector<Type> resTypes;
if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
return failure();
- rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
- curOp->getAttrs());
- return success();
- }
-};
-
-struct AtomicReadOpConversion
- : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
- using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- Type curElementType = curOp.getElementType();
- auto newOp = rewriter.create<omp::AtomicReadOp>(
- curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
- TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
- newOp.setElementTypeAttr(typeAttr);
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
- using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
-
- SmallVector<Type> resTypes;
- if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
- return failure();
-
- // Copy attributes of the curOp except for the typeAttr which should
- // be converted
- SmallVector<NamedAttribute> newAttrs;
+ // Translate type attributes.
+ // They are kept unmodified except if they are type attributes.
+ SmallVector<NamedAttribute> convertedAttrs;
for (NamedAttribute attr : curOp->getAttrs()) {
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
- Type newAttr = converter->convertType(typeAttr.getValue());
- newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+ Type convertedType = converter->convertType(typeAttr.getValue());
+ convertedAttrs.emplace_back(attr.getName(),
+ TypeAttr::get(convertedType));
} else {
- newAttrs.push_back(attr);
+ convertedAttrs.push_back(attr);
}
}
- rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
- curOp, resTypes, adaptor.getOperands(), newAttrs);
- return success();
- }
-};
-
-struct DeclMapperOpConversion
- : public ConvertOpToLLVMPattern<omp::DeclareMapperOp> {
- using ConvertOpToLLVMPattern<omp::DeclareMapperOp>::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(omp::DeclareMapperOp curOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
- SmallVector<NamedAttribute> newAttrs;
- newAttrs.emplace_back(curOp.getSymNameAttrName(), curOp.getSymNameAttr());
- newAttrs.emplace_back(
- curOp.getTypeAttrName(),
- TypeAttr::get(converter->convertType(curOp.getType())));
-
- auto newOp = rewriter.create<omp::DeclareMapperOp>(
- curOp.getLoc(), TypeRange(), adaptor.getOperands(), newAttrs);
- rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
- *this->getTypeConverter())))
- return failure();
-
- rewriter.eraseOp(curOp);
- return success();
- }
-};
-
-template <typename OpType>
-struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
- using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
-
- void forwardOpAttrs(OpType curOp, OpType newOp) const {}
+ // Translate operands.
+ SmallVector<Value> convertedOperands;
+ convertedOperands.reserve(curOp->getNumOperands());
+ for (auto [originalOperand, convertedOperand] :
+ llvm::zip_equal(curOp->getOperands(), adaptor.getOperands())) {
+ if (!originalOperand)
+ return failure();
- LogicalResult
- matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.create<OpType>(
- curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
- TypeAttr::get(this->getTypeConverter()->convertType(
- curOp.getTypeAttr().getValue())));
- forwardOpAttrs(curOp, newOp);
+ if constexpr (!SupportsMemRefOperand) {
+ if (isa<MemRefType>(originalOperand.getType())) {
+ // TODO: Support memref type in variable operands
+ return rewriter.notifyMatchFailure(curOp,
+ "memref is not supported yet");
+ }
+ }
+ convertedOperands.push_back(convertedOperand);
+ }
- for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
- rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
- newOp.getRegion(idx).end());
- if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
+ // Create new operation.
+ auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
+ convertedAttrs);
+
+ // Translate regions.
+ for (auto [originalRegion, convertedRegion] :
+ llvm::zip_equal(curOp->getRegions(), newOp->getRegions())) {
+ rewriter.inlineRegionBefore(originalRegion, convertedRegion,
+ convertedRegion.end());
+ if (failed(rewriter.convertRegionTypes(&convertedRegion,
*this->getTypeConverter())))
return failure();
}
- rewriter.eraseOp(curOp);
+ // Delete old operation and replace result uses with those of the new one.
+ rewriter.replaceOp(curOp, newOp->getResults());
return success();
}
};
-template <>
-void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
- omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
- newOp.setDataSharingType(curOp.getDataSharingType());
-}
} // namespace
void mlir::configureOpenMPToLLVMConversionLegality(
ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
target.addDynamicallyLegalOp<
- omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
- omp::CancelOp, omp::CriticalDeclareOp, omp::DeclareMapperInfoOp,
- omp::FlushOp, omp::MapBoundsOp, omp::MapInfoOp, omp::OrderedOp,
- omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
- omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>(
- [&](Operation *op) {
- return typeConverter.isLegal(op->getOperandTypes()) &&
- typeConverter.isLegal(op->getResultTypes());
- });
- target.addDynamicallyLegalOp<
- omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareMapperOp,
- omp::DeclareReductionOp, omp::DistributeOp, omp::LoopNestOp, omp::LoopOp,
- omp::MasterOp, omp::OrderedRegionOp, omp::ParallelOp,
- omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp, omp::SimdOp,
- omp::SingleOp, omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp,
- omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
- omp::WsloopOp>([&](Operation *op) {
- return std::all_of(op->getRegions().begin(), op->getRegions().end(),
+#define GET_OP_LIST
+#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
+ >([&](Operation *op) {
+ return typeConverter.isLegal(op->getOperandTypes()) &&
+ typeConverter.isLegal(op->getResultTypes()) &&
+ std::all_of(op->getRegions().begin(), op->getRegions().end(),
[&](Region &region) {
return typeConverter.isLegal(&region);
}) &&
- typeConverter.isLegal(op->getOperandTypes()) &&
- typeConverter.isLegal(op->getResultTypes());
+ std::all_of(op->getAttrs().begin(), op->getAttrs().end(),
+ [&](NamedAttribute attr) {
+ auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
+ return !typeAttr ||
+ typeConverter.isLegal(typeAttr.getValue());
+ });
});
- target.addDynamicallyLegalOp<omp::PrivateClauseOp>(
- [&](omp::PrivateClauseOp op) -> bool {
- return std::all_of(op->getRegions().begin(), op->getRegions().end(),
- [&](Region &region) {
- return typeConverter.isLegal(&region);
- }) &&
- typeConverter.isLegal(op->getOperandTypes()) &&
- typeConverter.isLegal(op->getResultTypes()) &&
- typeConverter.isLegal(op.getType());
- });
}
void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -295,36 +134,42 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
[&](omp::MapBoundsType type) -> Type { return type; });
patterns.add<
- AtomicReadOpConversion, DeclMapperOpConversion, MapInfoOpConversion,
- MultiRegionOpConversion<omp::DeclareReductionOp>,
- MultiRegionOpConversion<omp::PrivateClauseOp>,
- RegionLessOpConversion<omp::CancellationPointOp>,
- RegionLessOpConversion<omp::CancelOp>,
- RegionLessOpConversion<omp::CriticalDeclareOp>,
- RegionLessOpConversion<omp::DeclareMapperInfoOp>,
- RegionLessOpConversion<omp::OrderedOp>,
- RegionLessOpConversion<omp::ScanOp>,
- RegionLessOpConversion<omp::TargetEnterDataOp>,
- RegionLessOpConversion<omp::TargetExitDataOp>,
- RegionLessOpConversion<omp::TargetUpdateOp>,
- RegionLessOpConversion<omp::YieldOp>,
- RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
- RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
- RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
- RegionLessOpWithVarOperands...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/132009
More information about the Mlir-commits mailing list