[flang-commits] [flang] 4a3bf27 - [OpenMP] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (#145464)
via flang-commits
flang-commits at lists.llvm.org
Mon Aug 18 05:45:15 PDT 2025
Author: Chaitanya
Date: 2025-08-18T18:15:11+05:30
New Revision: 4a3bf27c69473e65a9176858ff57c8b55dfb184c
URL: https://github.com/llvm/llvm-project/commit/4a3bf27c69473e65a9176858ff57c8b55dfb184c
DIFF: https://github.com/llvm/llvm-project/commit/4a3bf27c69473e65a9176858ff57c8b55dfb184c.diff
LOG: [OpenMP] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (#145464)
This PR introduces two new ops in the omp dialect: omp.target_allocmem and
omp.target_freemem.
omp.target_allocmem: Allocates heap memory on the device. Will be lowered to an
omp_target_alloc call in LLVM IR.
omp.target_freemem: Deallocates heap memory on the device. Will be lowered to an
omp_target_free call in LLVM IR.
Example:
%device = arith.constant 0 : i32
%1 = omp.target_allocmem %device : i32, i64
omp.target_freemem %device, %1 : i32, i64
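For reference, a minimal sketch of the LLVM IR this lowers to for an i64
allocation on device 0 (taken from the tests added in this PR; value names are
illustrative):
%1 = call ptr @omp_target_alloc(i64 8, i32 0)
%2 = ptrtoint ptr %1 to i64
%3 = inttoptr i64 %2 to ptr
call void @omp_target_free(ptr %3, i32 0)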
The work in this PR is copied from and inspired by @ivanradanov's commits from
the coexecute implementation:
[Add fir omp target alloc and free
ops](https://github.com/ivanradanov/llvm-project/commit/be860ac8baf24b8405e6f396c75d7f0d26375de5)
[Lower omp_target_{alloc,free} to
llvm](https://github.com/ivanradanov/llvm-project/commit/6e2d584dc93ff99bb89adc28c7afbc2b21c46d39)
Added:
flang/test/Fir/omp_target_allocmem_freemem.fir
mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir
Modified:
flang/include/flang/Optimizer/Support/Utils.h
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/lib/Optimizer/Support/Utils.cpp
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index 83c936b7dcada..0b31cfea0430a 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -27,6 +27,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+
namespace fir {
/// Return the integer value of a arith::ConstantOp.
inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
@@ -198,6 +200,37 @@ std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
fir::RecordType recordType, llvm::StringRef component,
mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);
+/// Generate a LLVM constant value of type `ity`, using the provided offset.
+mlir::LLVM::ConstantOp
+genConstantIndex(mlir::Location loc, mlir::Type ity,
+ mlir::ConversionPatternRewriter &rewriter,
+ std::int64_t offset);
+
+/// Helper function for generating the LLVM IR that computes the distance
+/// in bytes between adjacent elements pointed to by a pointer
+/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
+/// type.
+mlir::Value computeElementDistance(mlir::Location loc,
+ mlir::Type llvmObjectType, mlir::Type idxTy,
+ mlir::ConversionPatternRewriter &rewriter,
+ const mlir::DataLayout &dataLayout);
+
+// Compute the alloc scale size (constant factors encoded in the array type).
+// We do this for arrays without a constant interior or arrays of character with
+// dynamic length arrays, since those are the only ones that get decayed to a
+// pointer to the element type.
+mlir::Value genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy,
+ mlir::Type ity,
+ mlir::ConversionPatternRewriter &rewriter);
+
+/// Perform an extension or truncation as needed on an integer value. Lowering
+/// to the specific target may involve some sign-extending or truncation of
+/// values, particularly to fit them from abstract box types to the
+/// appropriate reified structures.
+mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
+ mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::Type ty, mlir::Value val, bool fold = false);
} // namespace fir
#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ba5fef97c83ed..76f3cbd421cb9 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -87,14 +87,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
return mlir::IntegerType::get(context, 8);
}
-static mlir::LLVM::ConstantOp
-genConstantIndex(mlir::Location loc, mlir::Type ity,
- mlir::ConversionPatternRewriter &rewriter,
- std::int64_t offset) {
- auto cattr = rewriter.getI64IntegerAttr(offset);
- return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr);
-}
-
static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
mlir::Block *insertBefore) {
assert(insertBefore && "expected valid insertion block");
@@ -208,39 +200,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op,
TODO(op.getLoc(), "did not find allocation function");
}
-// Compute the alloc scale size (constant factors encoded in the array type).
-// We do this for arrays without a constant interior or arrays of character with
-// dynamic length arrays, since those are the only ones that get decayed to a
-// pointer to the element type.
-template <typename OP>
-static mlir::Value
-genAllocationScaleSize(OP op, mlir::Type ity,
- mlir::ConversionPatternRewriter &rewriter) {
- mlir::Location loc = op.getLoc();
- mlir::Type dataTy = op.getInType();
- auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
- fir::SequenceType::Extent constSize = 1;
- if (seqTy) {
- int constRows = seqTy.getConstantRows();
- const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
- if (constRows != static_cast<int>(shape.size())) {
- for (auto extent : shape) {
- if (constRows-- > 0)
- continue;
- if (extent != fir::SequenceType::getUnknownExtent())
- constSize *= extent;
- }
- }
- }
-
- if (constSize != 1) {
- mlir::Value constVal{
- genConstantIndex(loc, ity, rewriter, constSize).getResult()};
- return constVal;
- }
- return nullptr;
-}
-
namespace {
struct DeclareOpConversion : public fir::FIROpConversion<fir::cg::XDeclareOp> {
public:
@@ -275,7 +234,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
auto loc = alloc.getLoc();
mlir::Type ity = lowerTy().indexType();
unsigned i = 0;
- mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult();
+ mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult();
mlir::Type firObjType = fir::unwrapRefType(alloc.getType());
mlir::Type llvmObjectType = convertObjectType(firObjType);
if (alloc.hasLenParams()) {
@@ -307,7 +266,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
<< scalarType << " with type parameters";
}
}
- if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter))
+ if (auto scaleSize = fir::genAllocationScaleSize(
+ alloc.getLoc(), alloc.getInType(), ity, rewriter))
size =
rewriter.createOrFold<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
if (alloc.hasShapeOperands()) {
@@ -484,7 +444,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion<fir::BoxIsArrayOp> {
auto loc = boxisarray.getLoc();
TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType());
mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter);
- mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0);
+ mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0);
return mlir::success();
@@ -820,7 +780,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
// Do folding for constant inputs.
if (auto constVal = fir::getIntIfConstant(op0)) {
mlir::Value normVal =
- genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
+ fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
rewriter.replaceOp(convert, normVal);
return mlir::success();
}
@@ -833,7 +793,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
}
// Compare the input with zero.
- mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
+ mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0);
auto isTrue = mlir::LLVM::ICmpOp::create(
rewriter, loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);
@@ -1082,21 +1042,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
return getMallocInModule(mod, op, rewriter, indexType);
}
-/// Helper function for generating the LLVM IR that computes the distance
-/// in bytes between adjacent elements pointed to by a pointer
-/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
-/// type.
-static mlir::Value
-computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
- mlir::Type idxTy,
- mlir::ConversionPatternRewriter &rewriter,
- const mlir::DataLayout &dataLayout) {
- llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
- unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
- std::int64_t distance = llvm::alignTo(size, alignment);
- return genConstantIndex(loc, idxTy, rewriter, distance);
-}
-
/// Return value of the stride in bytes between adjacent elements
/// of LLVM type \p llTy. The result is returned as a value of
/// \p idxTy integer type.
@@ -1105,7 +1050,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy,
const mlir::DataLayout &dataLayout) {
// Create a pointer type and use computeElementDistance().
- return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
+ return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
}
namespace {
@@ -1124,8 +1069,9 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
TODO(loc, "fir.allocmem codegen of derived type with length parameters");
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
- if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
- size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
+ if (auto scaleSize =
+ fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter))
+ size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
for (mlir::Value opnd : adaptor.getOperands())
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size,
integerCast(loc, rewriter, ity, opnd));
@@ -1133,8 +1079,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
// As the return value of malloc(0) is implementation defined, allocate one
// byte to ensure the allocation status being true. This behavior aligns to
// what the runtime has.
- mlir::Value zero = genConstantIndex(loc, ity, rewriter, 0);
- mlir::Value one = genConstantIndex(loc, ity, rewriter, 1);
+ mlir::Value zero = fir::genConstantIndex(loc, ity, rewriter, 0);
+ mlir::Value one = fir::genConstantIndex(loc, ity, rewriter, 1);
mlir::Value cmp = mlir::LLVM::ICmpOp::create(
rewriter, loc, mlir::LLVM::ICmpPredicate::sgt, size, zero);
size = mlir::LLVM::SelectOp::create(rewriter, loc, cmp, size, one);
@@ -1157,7 +1103,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type llTy) const {
- return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
+ return fir::computeElementDistance(loc, llTy, idxTy, rewriter,
+ getDataLayout());
}
};
} // namespace
@@ -1344,7 +1291,7 @@ genCUFAllocDescriptor(mlir::Location loc,
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
mlir::Value sizeInBytes =
- genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
+ fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
return mlir::LLVM::CallOp::create(rewriter, loc, fctTy,
RTNAME_STRING(CUFAllocDescriptor), args)
@@ -1599,7 +1546,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
// representation of derived types with pointer/allocatable components.
// This has been seen in hashing algorithms using TRANSFER.
mlir::Value zero =
- genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
+ fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
descriptor = insertField(rewriter, loc, descriptor,
{getLenParamFieldId(boxTy), 0}, zero);
}
@@ -1944,8 +1891,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
bool hasSlice = !xbox.getSlice().empty();
unsigned sliceOffset = xbox.getSliceOperandIndex();
mlir::Location loc = xbox.getLoc();
- mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0);
- mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1);
+ mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0);
+ mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1);
mlir::Value prevPtrOff = one;
mlir::Type eleTy = boxTy.getEleTy();
const unsigned rank = xbox.getRank();
@@ -1994,7 +1941,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
prevDimByteStride =
getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
} else {
- prevDimByteStride = genConstantIndex(
+ prevDimByteStride = fir::genConstantIndex(
loc, i64Ty, rewriter,
charTy.getLen() * lowerTy().characterBitsize(charTy) / 8);
}
@@ -2152,7 +2099,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(inputEleTy)) {
if (charTy.hasConstantLen()) {
mlir::Value len =
- genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
+ fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
lenParams.emplace_back(len);
} else {
mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair,
@@ -2161,7 +2108,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
assert(!isInGlobalOp(rewriter) &&
"character target in global op must have constant length");
mlir::Value width =
- genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
+ fir::genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
len = mlir::LLVM::SDivOp::create(rewriter, loc, idxTy, len, width);
}
lenParams.emplace_back(len);
@@ -2215,8 +2162,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
mlir::Value zero =
- genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
- mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
+ fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
+ mlir::Value one =
+ fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
mlir::Value extent = std::get<0>(iter.value());
unsigned dim = iter.index();
@@ -2249,7 +2197,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
mlir::Location loc = rebox.getLoc();
mlir::Type byteTy = ::getI8Type(rebox.getContext());
mlir::Type idxTy = lowerTy().indexType();
- mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
+ mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0);
// Apply subcomponent and substring shift on base address.
if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) {
// Cast to inputEleTy* so that a GEP can be used.
@@ -2277,7 +2225,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
// and strides.
llvm::SmallVector<mlir::Value> slicedExtents;
llvm::SmallVector<mlir::Value> slicedStrides;
- mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
+ mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
const bool sliceHasOrigins = !rebox.getShift().empty();
unsigned sliceOps = rebox.getSliceOperandIndex();
unsigned shiftOps = rebox.getShiftOperandIndex();
@@ -2350,7 +2298,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
// which may be OK if all new extents are ones, the stride does not
// matter, use one.
mlir::Value stride = inputStrides.empty()
- ? genConstantIndex(loc, idxTy, rewriter, 1)
+ ? fir::genConstantIndex(loc, idxTy, rewriter, 1)
: inputStrides[0];
for (unsigned i = 0; i < rebox.getShape().size(); ++i) {
mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i];
@@ -2585,9 +2533,9 @@ struct XArrayCoorOpConversion
unsigned shiftOffset = coor.getShiftOperandIndex();
unsigned sliceOffset = coor.getSliceOperandIndex();
auto sliceOps = coor.getSlice().begin();
- mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
+ mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
mlir::Value prevExt = one;
- mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0);
+ mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0);
const bool isShifted = !coor.getShift().empty();
const bool isSliced = !coor.getSlice().empty();
const bool baseIsBoxed =
@@ -2918,7 +2866,7 @@ struct CoordinateOpConversion
// of lower bound aspects. This both accounts for dynamically sized
// types and non contiguous arrays.
auto idxTy = lowerTy().indexType();
- mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0);
+ mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0);
unsigned arrayDim = arrTy.getDimension();
for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) {
mlir::Value stride =
@@ -3846,7 +3794,7 @@ struct IsPresentOpConversion : public fir::FIROpConversion<fir::IsPresentOp> {
ptr = mlir::LLVM::ExtractValueOp::create(rewriter, loc, ptr, 0);
}
mlir::LLVM::ConstantOp c0 =
- genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
+ fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
auto addr = mlir::LLVM::PtrToIntOp::create(rewriter, loc, idxTy, ptr);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0);
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 37f1c9f97e1ce..97912bda79b08 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -21,6 +21,7 @@
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -125,10 +126,58 @@ struct PrivateClauseOpConversion
return mlir::success();
}
};
+
+// Convert FIR type to LLVM without turning fir.box<T> into memory
+// reference.
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
+ mlir::Type firType) {
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
+ return converter.convertBoxTypeAsStruct(boxTy);
+ return converter.convertType(firType);
+}
+
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
+ : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
+ using OpenMPFIROpConversion::OpenMPFIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::Type heapTy = allocmemOp.getAllocatedType();
+ mlir::Location loc = allocmemOp.getLoc();
+ auto ity = lowerTy().indexType();
+ mlir::Type dataTy = fir::unwrapRefType(heapTy);
+ mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+ if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
+ TODO(loc, "omp.target_allocmem codegen of derived type with length "
+ "parameters");
+ mlir::Value size = fir::computeElementDistance(
+ loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
+ if (auto scaleSize = fir::genAllocationScaleSize(
+ loc, allocmemOp.getInType(), ity, rewriter))
+ size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+ for (mlir::Value opnd : adaptor.getOperands().drop_front())
+ size = rewriter.create<mlir::LLVM::MulOp>(
+ loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
+ auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+ auto mallocTy =
+ mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
+ if (mallocTyWidth != ity.getIntOrFloatBitWidth())
+ size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
+ rewriter.modifyOpInPlace(allocmemOp, [&]() {
+ allocmemOp.setInType(rewriter.getI8Type());
+ allocmemOp.getTypeparamsMutable().clear();
+ allocmemOp.getTypeparamsMutable().append(size);
+ });
+ return mlir::success();
+ }
+};
} // namespace
void fir::populateOpenMPFIRToLLVMConversionPatterns(
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
patterns.add<MapInfoOpConversion>(converter);
patterns.add<PrivateClauseOpConversion>(converter);
+ patterns.add<TargetAllocMemOpConversion>(converter);
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 01975f357a8da..87f9899aa7879 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) {
}
/// Parser shared by Alloca and Allocmem
-///
/// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type
/// ( `(` $typeparams `)` )? ( `,` $shape )?
/// attr-dict-without-keyword
diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp
index 5d663e28336c0..c71642ce4e806 100644
--- a/flang/lib/Optimizer/Support/Utils.cpp
+++ b/flang/lib/Optimizer/Support/Utils.cpp
@@ -50,3 +50,74 @@ std::optional<llvm::ArrayRef<int64_t>> fir::getComponentLowerBoundsIfNonDefault(
return componentInfo.getLowerBounds();
return std::nullopt;
}
+
+mlir::LLVM::ConstantOp
+fir::genConstantIndex(mlir::Location loc, mlir::Type ity,
+ mlir::ConversionPatternRewriter &rewriter,
+ std::int64_t offset) {
+ auto cattr = rewriter.getI64IntegerAttr(offset);
+ return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
+}
+
+mlir::Value
+fir::computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
+ mlir::Type idxTy,
+ mlir::ConversionPatternRewriter &rewriter,
+ const mlir::DataLayout &dataLayout) {
+ llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
+ unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
+ std::int64_t distance = llvm::alignTo(size, alignment);
+ return fir::genConstantIndex(loc, idxTy, rewriter, distance);
+}
+
+mlir::Value
+fir::genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy,
+ mlir::Type ity,
+ mlir::ConversionPatternRewriter &rewriter) {
+ auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
+ fir::SequenceType::Extent constSize = 1;
+ if (seqTy) {
+ int constRows = seqTy.getConstantRows();
+ const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
+ if (constRows != static_cast<int>(shape.size())) {
+ for (auto extent : shape) {
+ if (constRows-- > 0)
+ continue;
+ if (extent != fir::SequenceType::getUnknownExtent())
+ constSize *= extent;
+ }
+ }
+ }
+
+ if (constSize != 1) {
+ mlir::Value constVal{
+ fir::genConstantIndex(loc, ity, rewriter, constSize).getResult()};
+ return constVal;
+ }
+ return nullptr;
+}
+
+mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter,
+ mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::Type ty, mlir::Value val, bool fold) {
+ auto valTy = val.getType();
+ // If the value was not yet lowered, lower its type so that it can
+ // be used in getPrimitiveTypeSizeInBits.
+ if (!mlir::isa<mlir::IntegerType>(valTy))
+ valTy = converter.convertType(valTy);
+ auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+ auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
+ if (fold) {
+ if (toSize < fromSize)
+ return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
+ if (toSize > fromSize)
+ return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
+ } else {
+ if (toSize < fromSize)
+ return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+ if (toSize > fromSize)
+ return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
+ }
+ return val;
+}
diff --git a/flang/test/Fir/omp_target_allocmem_freemem.fir b/flang/test/Fir/omp_target_allocmem_freemem.fir
new file mode 100644
index 0000000000000..03eb94acb1ac7
--- /dev/null
+++ b/flang/test/Fir/omp_target_allocmem_freemem.fir
@@ -0,0 +1,294 @@
+// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s
+
+// UNSUPPORTED: system-windows
+// Disabled on 32-bit targets due to the additional `trunc` opcodes required
+// UNSUPPORTED: target-x86
+// UNSUPPORTED: target=sparc-{{.*}}
+// UNSUPPORTED: target=sparcel-{{.*}}
+
+// CHECK-LABEL: define void @omp_target_allocmem_scalar_nonchar() {
+// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 4, i32 0)
+// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_scalar_nonchar() -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, i32
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_scalars_nonchar() {
+// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 400, i32 0)
+// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_scalars_nonchar() -> () {
+ %device = arith.constant 0 : i32
+ %0 = arith.constant 100 : index
+ %1 = omp.target_allocmem %device : i32, i32, %0
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_scalar_char() {
+// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 10, i32 0)
+// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_scalar_char() -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.char<1,10>
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_scalar_char_kind() {
+// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 20, i32 0)
+// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_scalar_char_kind() -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.char<2,10>
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_scalar_dynchar(
+// CHECK-SAME: i32 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]]
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]]
+// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0)
+// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_scalar_dynchar(%l : i32) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.char<1,?>(%l : i32)
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+
+// CHECK-LABEL: define void @omp_target_allocmem_scalar_dynchar_kind(
+// CHECK-SAME: i32 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 2, [[TMP2]]
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]]
+// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0)
+// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_scalar_dynchar_kind(%l : i32) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.char<2,?>(%l : i32)
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar() {
+// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 36, i32 0)
+// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_array_of_nonchar() -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x3xi32>
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_of_char() {
+// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 90, i32 0)
+// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_array_of_char() -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>>
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar(
+// CHECK-SAME: i32 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 9, [[TMP2]]
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]]
+// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0)
+// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+
+// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_nonchar(
+// CHECK-SAME: i64 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = mul i64 12, [[TMP0]]
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]]
+// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0)
+// CHECK-NEXT: [[TMP5:%.*]] = ptrtoint ptr [[TMP4]] to i64
+// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_dynarray_of_nonchar(%e: index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x?xi32>, %e
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_nonchar2(
+// CHECK-SAME: i64 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = mul i64 4, [[TMP0]]
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], [[TMP0]]
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]]
+// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0)
+// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_dynarray_of_nonchar2(%e: index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<?x?xi32>, %e, %e
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_char(
+// CHECK-SAME: i64 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = mul i64 60, [[TMP0]]
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]]
+// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0)
+// CHECK-NEXT: [[TMP5:%.*]] = ptrtoint ptr [[TMP4]] to i64
+// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_dynarray_of_char(%e : index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x?x!fir.char<2,10>>, %e
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+
+// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_char2(
+// CHECK-SAME: i64 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = mul i64 20, [[TMP0]]
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], [[TMP0]]
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]]
+// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0)
+// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_dynarray_of_char2(%e : index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<?x?x!fir.char<2,10>>, %e, %e
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_dynchar(
+// CHECK-SAME: i32 [[TMP0:%.*]], i64 [[TMP1:%.*]]) {
+// CHECK-NEXT: [[TMP3:%.*]] = sext i32 [[TMP0]] to i64
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 6, [[TMP3]]
+// CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], [[TMP1]]
+// CHECK-NEXT: [[TMP6:%.*]] = mul i64 1, [[TMP5]]
+// CHECK-NEXT: [[TMP7:%.*]] = call ptr @omp_target_alloc(i64 [[TMP6]], i32 0)
+// CHECK-NEXT: [[TMP8:%.*]] = ptrtoint ptr [[TMP7]] to i64
+// CHECK-NEXT: [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP9]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_dynarray_of_dynchar(%l: i32, %e : index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x?x!fir.char<2,?>>(%l : i32), %e
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_dynchar2(
+// CHECK-SAME: i32 [[TMP0:%.*]], i64 [[TMP1:%.*]]) {
+// CHECK-NEXT: [[TMP3:%.*]] = sext i32 [[TMP0]] to i64
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 2, [[TMP3]]
+// CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], [[TMP1]]
+// CHECK-NEXT: [[TMP6:%.*]] = mul i64 [[TMP5]], [[TMP1]]
+// CHECK-NEXT: [[TMP7:%.*]] = mul i64 1, [[TMP6]]
+// CHECK-NEXT: [[TMP8:%.*]] = call ptr @omp_target_alloc(i64 [[TMP7]], i32 0)
+// CHECK-NEXT: [[TMP9:%.*]] = ptrtoint ptr [[TMP8]] to i64
+// CHECK-NEXT: [[TMP10:%.*]] = inttoptr i64 [[TMP9]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP10]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_dynarray_of_dynchar2(%l: i32, %e : index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<?x?x!fir.char<2,?>>(%l : i32), %e, %e
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_nonchar(
+// CHECK-SAME: i64 [[TMP0:%.*]], i64 [[TMP1:%.*]]) {
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 240, [[TMP0]]
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP3]], [[TMP1]]
+// CHECK-NEXT: [[TMP5:%.*]] = mul i64 1, [[TMP4]]
+// CHECK-NEXT: [[TMP6:%.*]] = call ptr @omp_target_alloc(i64 [[TMP5]], i32 0)
+// CHECK-NEXT: [[TMP7:%.*]] = ptrtoint ptr [[TMP6]] to i64
+// CHECK-NEXT: [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP8]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_array_with_holes_nonchar(%0 : index, %1 : index) -> () {
+ %device = arith.constant 0 : i32
+ %2 = omp.target_allocmem %device : i32, !fir.array<4x?x3x?x5xi32>, %0, %1
+ omp.target_freemem %device, %2 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_char(
+// CHECK-SAME: i64 [[TMP0:%.*]]) {
+// CHECK-NEXT: [[TMP2:%.*]] = mul i64 240, [[TMP0]]
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]]
+// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0)
+// CHECK-NEXT: [[TMP5:%.*]] = ptrtoint ptr [[TMP4]] to i64
+// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_array_with_holes_char(%e: index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x?x4x!fir.char<2,10>>, %e
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_dynchar(
+// CHECK-SAME: i64 [[TMP0:%.*]], i64 [[TMP1:%.*]]) {
+// CHECK-NEXT: [[TMP3:%.*]] = mul i64 24, [[TMP0]]
+// CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP3]], [[TMP1]]
+// CHECK-NEXT: [[TMP5:%.*]] = mul i64 1, [[TMP4]]
+// CHECK-NEXT: [[TMP6:%.*]] = call ptr @omp_target_alloc(i64 [[TMP5]], i32 0)
+// CHECK-NEXT: [[TMP7:%.*]] = ptrtoint ptr [[TMP6]] to i64
+// CHECK-NEXT: [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr
+// CHECK-NEXT: call void @omp_target_free(ptr [[TMP8]], i32 0)
+// CHECK-NEXT: ret void
+func.func @omp_target_allocmem_array_with_holes_dynchar(%arg0: index, %arg1: index) -> () {
+ %device = arith.constant 0 : i32
+ %1 = omp.target_allocmem %device : i32, !fir.array<3x?x4x!fir.char<2,?>>(%arg0 : index), %arg1
+ omp.target_freemem %device, %1 : i32, i64
+ return
+}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index be114ea4fb631..c956d69781b3d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2115,4 +2115,98 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
+ [MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
+ let summary = "allocate storage on an openmp device for an object of a given type";
+
+ let description = [{
+ Allocates memory on the specified OpenMP device for an object of the given type.
+ Returns an integer value representing the device pointer to the allocated memory.
+ The memory is uninitialized after allocation. Operations must be paired with
+ `omp.target_freemem` to avoid memory leaks.
+
+ * `$device`: The integer ID of the OpenMP device where the memory will be allocated.
+ * `$in_type`: The type of the object for which memory is being allocated.
+ For arrays, this can be a static or dynamic array type.
+ * `$uniq_name`: An optional unique name for the allocated memory.
+ * `$bindc_name`: An optional name used for C interoperability.
+ * `$typeparams`: Runtime type parameters for polymorphic or parameterized types.
+ These are typically integer values that define aspects of a type not fixed at compile time.
+ * `$shape`: Runtime shape operands for dynamic arrays.
+ Each operand is an integer value representing the extent of a specific dimension.
+
+ ```mlir
+ // Allocate a static 3x3 integer vector on device 0
+ %device_0 = arith.constant 0 : i32
+ %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
+ // ... use %ptr_static ...
+ omp.target_freemem %device_0, %ptr_static : i32, i64
+
+ // Allocate a dynamic 2D Fortran array (fir.array) on device 1
+ %device_1 = arith.constant 1 : i32
+ %rows = arith.constant 10 : index
+ %cols = arith.constant 20 : index
+ %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
+ // ... use %ptr_dynamic ...
+ omp.target_freemem %device_1, %ptr_dynamic : i32, i64
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<AnyInteger>:$device,
+ TypeAttr:$in_type,
+ OptionalAttr<StrAttr>:$uniq_name,
+ OptionalAttr<StrAttr>:$bindc_name,
+ Variadic<IntLikeType>:$typeparams,
+ Variadic<IntLikeType>:$shape
+ );
+ let results = (outs I64);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ mlir::Type getAllocatedType();
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// TargetFreeMemOp
+//===----------------------------------------------------------------------===//
+
+def TargetFreeMemOp : OpenMP_Op<"target_freemem",
+ [MemoryEffects<[MemFree]>]> {
+ let summary = "free memory on an openmp device";
+
+ let description = [{
+ Deallocates memory on the specified OpenMP device that was previously
+ allocated by an `omp.target_allocmem` operation. After this operation, the
+ deallocated memory is in an undefined state and should not be accessed.
+ It is crucial to ensure that all accesses to the memory region are completed
+ before `omp.target_freemem` is called to avoid undefined behavior.
+
+ * `$device`: The integer ID of the OpenMP device from which the memory will be freed.
+ * `$heapref`: The integer value representing the device pointer to the memory
+ to be deallocated, which was previously returned by `omp.target_allocmem`.
+
+ ```mlir
+ // Example of allocating and freeing memory on an OpenMP device
+ %device_id = arith.constant 0 : i32
+ %allocated_ptr = omp.target_allocmem %device_id : i32, vector<3x3xi32>
+ // ... operations using %allocated_ptr on the device ...
+ omp.target_freemem %device_id, %allocated_ptr : i32, i64
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<AnyInteger, "", [MemFree]>:$device,
+ Arg<I64, "", [MemFree]>:$heapref
+ );
+ let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
+}
+
#endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767ef90b0..fa94219016c1e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3874,6 +3874,107 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
+ return getInTypeAttr().getValue();
+}
+
+/// operation ::= %res = (`omp.target_allocmem`) $device : devicetype,
+/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+/// attr-dict-without-keyword
+static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ auto &builder = parser.getBuilder();
+ bool hasOperands = false;
+ std::int32_t typeparamsSize = 0;
+
+ // Parse device number as a new operand
+ mlir::OpAsmParser::UnresolvedOperand deviceOperand;
+ mlir::Type deviceType;
+ if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
+ return mlir::failure();
+ if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
+ return mlir::failure();
+ if (parser.parseComma())
+ return mlir::failure();
+
+ mlir::Type intype;
+ if (parser.parseType(intype))
+ return mlir::failure();
+ result.addAttribute("in_type", mlir::TypeAttr::get(intype));
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::Type> typeVec;
+ if (!parser.parseOptionalLParen()) {
+ // parse the LEN params of the derived type. (<params> : <types>)
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(typeVec) || parser.parseRParen())
+ return mlir::failure();
+ typeparamsSize = operands.size();
+ hasOperands = true;
+ }
+ std::int32_t shapeSize = 0;
+ if (!parser.parseOptionalComma()) {
+ // parse size to scale by, vector of n dimensions of type index
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
+ return mlir::failure();
+ shapeSize = operands.size() - typeparamsSize;
+ auto idxTy = builder.getIndexType();
+ for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
+ typeVec.push_back(idxTy);
+ hasOperands = true;
+ }
+ if (hasOperands &&
+ parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
+ result.operands))
+ return mlir::failure();
+
+ mlir::Type restype = builder.getIntegerType(64);
+ if (!restype) {
+ parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
+ return mlir::failure();
+ }
+ llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
+ result.addAttribute("operandSegmentSizes",
+ builder.getDenseI32ArrayAttr(segmentSizes));
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.addTypeToList(restype, result.types))
+ return mlir::failure();
+ return mlir::success();
+}
+
+mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseTargetAllocMemOp(parser, result);
+}
+
+void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
+ p << " ";
+ p.printOperand(getDevice());
+ p << " : ";
+ p << getDevice().getType();
+ p << ", ";
+ p << getInType();
+ if (!getTypeparams().empty()) {
+ p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
+ }
+ for (auto sh : getShape()) {
+ p << ", ";
+ p.printOperand(sh);
+ }
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ {"in_type", "operandSegmentSizes"});
+}
+
+llvm::LogicalResult omp::TargetAllocMemOp::verify() {
+ mlir::Type outType = getType();
+ if (!mlir::dyn_cast<IntegerType>(outType))
+ return emitOpError("must be a integer type");
+ return mlir::success();
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index eb96cb211fdd5..6694de8383534 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5867,6 +5867,10 @@ static bool isTargetDeviceOp(Operation *op) {
if (mlir::isa<omp::ThreadprivateOp>(op))
return true;
+ if (mlir::isa<omp::TargetAllocMemOp>(op) ||
+ mlir::isa<omp::TargetFreeMemOp>(op))
+ return true;
+
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
if (auto declareTargetIface =
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
@@ -5879,6 +5883,85 @@ static bool isTargetDeviceOp(Operation *op) {
return false;
}
+static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *i64Ty = builder.getInt64Ty();
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *returnType = builder.getPtrTy(0);
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
+ llvm::Function *func = cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
+ return func;
+}
+
+static LogicalResult
+convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
+ if (!allocMemOp)
+ return failure();
+
+ // Get "omp_target_alloc" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+ llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = allocMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the allocation size.
+ llvm::DataLayout dataLayout = llvmModule->getDataLayout();
+ mlir::Type heapTy = allocMemOp.getAllocatedType();
+ llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
+ llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+ for (auto typeParam : allocMemOp.getTypeparams())
+ allocSize =
+ builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+ // Create call to "omp_target_alloc" with the args as translated llvm values.
+ llvm::CallInst *call =
+ builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
+ llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
+
+ // Map the result
+ moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
+ return success();
+}
+
+static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *ptrTy = builder.getPtrTy(0);
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *voidTy = builder.getVoidTy();
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
+ llvm::Function *func = dyn_cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
+ return func;
+}
+
+static LogicalResult
+convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
+ if (!freeMemOp)
+ return failure();
+
+ // Get "omp_target_free" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+ llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = freeMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the corresponding heapref value in llvm
+ mlir::Value heapref = freeMemOp.getHeapref();
+ llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
+ // Convert heapref int to ptr and call "omp_target_free"
+ llvm::Value *intToPtr =
+ builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
+ builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
+ return success();
+}
+
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
static LogicalResult
@@ -6053,6 +6136,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
// the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation);
})
+ .Case([&](omp::TargetAllocMemOp) {
+ return convertTargetAllocMemOp(*op, builder, moduleTranslation);
+ })
+ .Case([&](omp::TargetFreeMemOp) {
+ return convertTargetFreeMemOp(*op, builder, moduleTranslation);
+ })
.Default([&](Operation *inst) {
return inst->emitError()
<< "not yet implemented: " << inst->getName();
diff --git a/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir
new file mode 100644
index 0000000000000..1bc97609ccff4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-openmp-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s
+
+// This file contains MLIR test cases for omp.target_allocmem and omp.target_freemem
+
+// CHECK-LABEL: test_alloc_free_i64
+// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 8, i32 0)
+// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64
+// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr
+// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0)
+// CHECK: ret void
+llvm.func @test_alloc_free_i64() -> () {
+ %device = llvm.mlir.constant(0 : i32) : i32
+ %1 = omp.target_allocmem %device : i32, i64
+ omp.target_freemem %device, %1 : i32, i64
+ llvm.return
+}
+
+// CHECK-LABEL: test_alloc_free_vector_1d_f32
+// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 64, i32 0)
+// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64
+// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr
+// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0)
+// CHECK: ret void
+llvm.func @test_alloc_free_vector_1d_f32() -> () {
+ %device = llvm.mlir.constant(0 : i32) : i32
+ %1 = omp.target_allocmem %device : i32, vector<16xf32>
+ omp.target_freemem %device, %1 : i32, i64
+ llvm.return
+}
+
+// CHECK-LABEL: test_alloc_free_vector_2d_f32
+// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 1024, i32 0)
+// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64
+// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr
+// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0)
+// CHECK: ret void
+llvm.func @test_alloc_free_vector_2d_f32() -> () {
+ %device = llvm.mlir.constant(0 : i32) : i32
+ %1 = omp.target_allocmem %device : i32, vector<16x16xf32>
+ omp.target_freemem %device, %1 : i32, i64
+ llvm.return
+}