[flang-commits] [flang] [flang] Introduce omp_target_allocmem and omp_target_freemem fir ops. (PR #145464)
via flang-commits
flang-commits at lists.llvm.org
Wed Jun 25 21:51:58 PDT 2025
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/145464
>From 4915fb8f1d27847dc3d36899e233d5ac988f96c5 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 23 Jun 2025 16:39:55 +0530
Subject: [PATCH 1/2] [flang] Introduce omp_target_allocmem and
omp_target_freemem fir ops.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 58 ++++++++++
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 102 +++++++++++++++++-
2 files changed, 159 insertions(+), 1 deletion(-)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8ac847dd7dd0a..2dff0f05fade7 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -517,6 +517,64 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> {
let assemblyFormat = "type($intype) attr-dict";
}
+def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
+ [MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
+ let summary = "allocate storage on an openmp device for an object of a given type";
+
+ let description = [{
+ Creates a heap memory reference suitable for storing a value of the
+ given type, T. The heap reference returned has type `!fir.heap<T>`.
+ The memory object is in an undefined state. `omp_target_allocmem` operations must
+ be paired with `omp_target_freemem` operations to avoid memory leaks.
+
+ ```
+ %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<AnyIntegerType>:$device,
+ TypeAttr:$in_type,
+ OptionalAttr<StrAttr>:$uniq_name,
+ OptionalAttr<StrAttr>:$bindc_name,
+ Variadic<AnyIntegerType>:$typeparams,
+ Variadic<AnyIntegerType>:$shape
+ );
+ let results = (outs fir_HeapType);
+
+ let extraClassDeclaration = [{
+ mlir::Type getAllocatedType();
+ bool hasLenParams() { return !getTypeparams().empty(); }
+ bool hasShapeOperands() { return !getShape().empty(); }
+ unsigned numLenParams() { return getTypeparams().size(); }
+ operand_range getLenParams() { return getTypeparams(); }
+ unsigned numShapeOperands() { return getShape().size(); }
+ operand_range getShapeOperands() { return getShape(); }
+ static mlir::Type getRefTy(mlir::Type ty);
+ }];
+}
+
+def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
+ [MemoryEffects<[MemFree]>]> {
+ let summary = "free storage that was allocated on an openmp device";
+
+ let description = [{
+ Deallocates a heap memory reference that was allocated by an `omp_target_allocmem`.
+ The memory object that is deallocated is placed in an undefined state
+ after `fir.omp_target_freemem`.
+ ```
+ %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
+ ...
+ "fir.omp_target_freemem"(%device, %0) : (i32, !fir.heap<!fir.array<?xf32>>) -> ()
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<AnyIntegerType, "", [MemFree]>:$device,
+ Arg<fir_HeapType, "", [MemFree]>:$heapref
+ );
+}
+
//===----------------------------------------------------------------------===//
// Terminator operations
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..042ade6b1e0a1 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1168,6 +1168,105 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
};
} // namespace
+/// Return the `omp_target_alloc` function of the module enclosing \p op.
+/// When the module has no such symbol yet, a declaration of type
+/// `ptr (i64, i32)` is created in the module body region and returned.
+static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
+  auto parentModule = op->getParentOfType<mlir::ModuleOp>();
+  auto existingDecl =
+      parentModule.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc");
+  if (existingDecl)
+    return existingDecl;
+
+  // No declaration yet: describe the signature, then emit it.
+  auto *ctx = parentModule->getContext();
+  auto sizeTy = mlir::IntegerType::get(ctx, 64);
+  auto deviceTy = mlir::IntegerType::get(ctx, 32);
+  auto allocFnTy = mlir::LLVM::LLVMFunctionType::get(
+      mlir::LLVM::LLVMPointerType::get(ctx), {sizeTy, deviceTy},
+      /*isVarArg=*/false);
+  mlir::OpBuilder declBuilder(parentModule.getBodyRegion());
+  return declBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      declBuilder.getUnknownLoc(), "omp_target_alloc", allocFnTy);
+}
+
+namespace {
+/// Lower `fir.omp_target_allocmem` to a call to `omp_target_alloc`.
+///
+/// The byte size handed to the runtime is the size of the lowered object type,
+/// multiplied by the allocation scale (when genAllocationScaleSize yields one)
+/// and by every type-parameter and shape operand of the operation.
+struct OmpTargetAllocMemOpConversion
+    : public fir::FIROpConversion<fir::OmpTargetAllocMemOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Type heapTy = heap.getType();
+    mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(heap);
+    mlir::Location loc = heap.getLoc();
+    auto ity = lowerTy().indexType();
+    mlir::Type dataTy = fir::unwrapRefType(heapTy);
+    mlir::Type llvmObjectTy = convertObjectType(dataTy);
+    if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
+      TODO(loc, "fir.omp_target_allocmem codegen of derived type with length "
+                "parameters");
+    mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
+    if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
+      size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+    // Operand 0 is the device number; only the remaining operands (type
+    // parameters and shape extents) contribute to the allocation size.
+    for (mlir::Value opnd : adaptor.getOperands().drop_front())
+      size = rewriter.create<mlir::LLVM::MulOp>(
+          loc, ity, size, integerCast(loc, rewriter, ity, opnd));
+    // NOTE(review): the size is cast to the index bit width here while
+    // getOmpTargetAlloc declares the size parameter as i64 - confirm the two
+    // agree on every supported target.
+    auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+    auto mallocTy =
+        mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
+    if (mallocTyWidth != ity.getIntOrFloatBitWidth())
+      size = integerCast(loc, rewriter, mallocTy, size);
+    // Take the device number from the adaptor so the call uses the converted
+    // operand, and cast it to i32: the op accepts any integer type for the
+    // device but the callee is declared with an i32 device parameter.
+    auto deviceTy = mlir::IntegerType::get(rewriter.getContext(), 32);
+    mlir::Value device =
+        integerCast(loc, rewriter, deviceTy, adaptor.getDevice());
+    heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
+    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
+        heap, ::getLlvmPtrType(heap.getContext()),
+        mlir::SmallVector<mlir::Value, 2>({size, device}),
+        addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2));
+    return mlir::success();
+  }
+
+  /// Compute the allocation size in bytes of the element type of
+  /// \p llTy pointer type. The result is returned as a value of \p idxTy
+  /// integer type.
+  mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type llTy) const {
+    return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
+  }
+};
+} // namespace
+
+/// Return the `omp_target_free` function of the module enclosing \p op.
+/// When the module has no such symbol yet, a declaration of type
+/// `void (ptr, i32)` is created in the module body region and returned.
+static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) {
+  auto parentModule = op->getParentOfType<mlir::ModuleOp>();
+  auto existingDecl =
+      parentModule.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_free");
+  if (existingDecl)
+    return existingDecl;
+
+  // No declaration yet: describe the signature, then emit it.
+  auto *ctx = parentModule->getContext();
+  auto deviceTy = mlir::IntegerType::get(ctx, 32);
+  auto freeFnTy = mlir::LLVM::LLVMFunctionType::get(
+      mlir::LLVM::LLVMVoidType::get(ctx), {getLlvmPtrType(ctx), deviceTy},
+      /*isVarArg=*/false);
+  mlir::OpBuilder declBuilder(parentModule.getBodyRegion());
+  return declBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      declBuilder.getUnknownLoc(), "omp_target_free", freeFnTy);
+}
+
+namespace {
+/// Lower `fir.omp_target_freemem` to a call to `omp_target_free`, passing the
+/// heap reference followed by the device number, then erase the FIR op.
+struct OmpTargetFreeMemOpConversion
+    : public fir::FIROpConversion<fir::OmpTargetFreeMemOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem);
+    mlir::Location loc = freemem.getLoc();
+    // Take the device number from the adaptor so the call uses the converted
+    // operand, and cast it to i32: the op accepts any integer type for the
+    // device but the callee is declared with an i32 device parameter.
+    auto deviceTy = mlir::IntegerType::get(rewriter.getContext(), 32);
+    mlir::Value device =
+        integerCast(loc, rewriter, deviceTy, adaptor.getDevice());
+    freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
+    rewriter.create<mlir::LLVM::CallOp>(
+        loc, mlir::TypeRange{},
+        mlir::ValueRange{adaptor.getHeapref(), device},
+        addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2));
+    rewriter.eraseOp(freemem);
+    return mlir::success();
+  }
+};
+} // namespace
+
// Convert subcomponent array indices from column-major to row-major ordering.
static llvm::SmallVector<mlir::Value>
convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
@@ -4274,7 +4373,8 @@ void fir::populateFIRToLLVMConversionPatterns(
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion,
- NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion,
+ NoReassocOpConversion, OmpTargetAllocMemOpConversion,
+ OmpTargetFreeMemOpConversion, SelectCaseOpConversion, SelectOpConversion,
SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
StoreOpConversion, StringLitOpConversion, SubcOpConversion,
>From e41a2c76786538fc411e104542a0282cfebef4f7 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 26 Jun 2025 10:20:36 +0530
Subject: [PATCH 2/2] [flang] Fix parsing and printing.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 13 ++-
flang/lib/Optimizer/Dialect/FIROps.cpp | 90 ++++++++++++++++---
flang/test/Fir/omp_target_allocmem.fir | 28 ++++++
flang/test/Fir/omp_target_freemem.fir | 28 ++++++
4 files changed, 145 insertions(+), 14 deletions(-)
create mode 100644 flang/test/Fir/omp_target_allocmem.fir
create mode 100644 flang/test/Fir/omp_target_freemem.fir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 2dff0f05fade7..666b66a8670d6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -528,7 +528,8 @@ def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
be paired with `omp_target_freemem` operations to avoid memory leaks.
```
- %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
```
}];
@@ -542,6 +543,9 @@ def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
);
let results = (outs fir_HeapType);
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
let extraClassDeclaration = [{
mlir::Type getAllocatedType();
bool hasLenParams() { return !getTypeparams().empty(); }
@@ -563,9 +567,9 @@ def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
The memory object that is deallocated is placed in an undefined state
after `fir.omp_target_freemem`.
```
- %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
- ...
- "fir.omp_target_freemem"(%device, %0) : (i32, !fir.heap<!fir.array<?xf32>>) -> ()
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
+ fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3xi32>>
```
}];
@@ -573,6 +577,7 @@ def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
Arg<AnyIntegerType, "", [MemFree]>:$device,
Arg<fir_HeapType, "", [MemFree]>:$heapref
);
+ let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
}
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ecfa2939e96a6..9335a4b041ac8 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -106,24 +106,38 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) {
return false;
}
-/// Parser shared by Alloca and Allocmem
-///
+/// Parser shared by Alloca, Allocmem and OmpTargetAllocmem
+/// boolean flag isTargetOp is used to identify omp_target_allocmem
/// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type
/// ( `(` $typeparams `)` )? ( `,` $shape )?
/// attr-dict-without-keyword
+/// operation ::= %res = (`fir.omp_target_allocmem`) $device : devicetype,
+/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+/// attr-dict-without-keyword
template <typename FN>
-static mlir::ParseResult parseAllocatableOp(FN wrapResultType,
- mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
+static mlir::ParseResult
+parseAllocatableOp(FN wrapResultType, mlir::OpAsmParser &parser,
+ mlir::OperationState &result, bool isTargetOp = false) {
+ auto &builder = parser.getBuilder();
+ bool hasOperands = false;
+ std::int32_t typeparamsSize = 0;
+ // Parse device number as a new operand
+ if (isTargetOp) {
+ mlir::OpAsmParser::UnresolvedOperand deviceOperand;
+ mlir::Type deviceType;
+ if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
+ return mlir::failure();
+ if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
+ return mlir::failure();
+ if (parser.parseComma())
+ return mlir::failure();
+ }
mlir::Type intype;
if (parser.parseType(intype))
return mlir::failure();
- auto &builder = parser.getBuilder();
result.addAttribute("in_type", mlir::TypeAttr::get(intype));
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
llvm::SmallVector<mlir::Type> typeVec;
- bool hasOperands = false;
- std::int32_t typeparamsSize = 0;
if (!parser.parseOptionalLParen()) {
// parse the LEN params of the derived type. (<params> : <types>)
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
@@ -147,13 +161,19 @@ static mlir::ParseResult parseAllocatableOp(FN wrapResultType,
parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
result.operands))
return mlir::failure();
+
mlir::Type restype = wrapResultType(intype);
if (!restype) {
parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
return mlir::failure();
}
- result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr(
- {typeparamsSize, shapeSize}));
+ llvm::SmallVector<std::int32_t> segmentSizes;
+ if (isTargetOp)
+ segmentSizes.push_back(1);
+ segmentSizes.push_back(typeparamsSize);
+ segmentSizes.push_back(shapeSize);
+ result.addAttribute("operandSegmentSizes",
+ builder.getDenseI32ArrayAttr(segmentSizes));
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.addTypeToList(restype, result.types))
return mlir::failure();
@@ -385,6 +405,56 @@ llvm::LogicalResult fir::AllocMemOp::verify() {
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// OmpTargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+/// Return the element type wrapped by this operation's `!fir.heap` result
+/// type.
+mlir::Type fir::OmpTargetAllocMemOp::getAllocatedType() {
+  auto heapTy = mlir::cast<fir::HeapType>(getType());
+  return heapTy.getEleTy();
+}
+
+/// Build the `!fir.heap` type that wraps \p ty.
+mlir::Type fir::OmpTargetAllocMemOp::getRefTy(mlir::Type ty) {
+  mlir::Type heapOfTy = fir::HeapType::get(ty);
+  return heapOfTy;
+}
+
+/// Parse the custom assembly form through the allocation parser shared with
+/// the other allocation operations.
+mlir::ParseResult
+fir::OmpTargetAllocMemOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  // This op carries a leading device operand, which the shared parser only
+  // consumes when told it is handling the target variant.
+  constexpr bool isTargetOp = true;
+  return parseAllocatableOp(wrapAllocMemResultType, parser, result,
+                            isTargetOp);
+}
+
+/// Print the custom assembly form:
+///   $device : type($device), $in_type ( `(` $typeparams : types `)` )?
+///   ( `,` $shape )* attr-dict
+/// The `in_type` and `operandSegmentSizes` attributes are elided from the
+/// attribute dictionary.
+void fir::OmpTargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
+  p << " ";
+  p.printOperand(getDevice());
+  p << " : " << getDevice().getType() << ", " << getInType();
+
+  auto lenParams = getTypeparams();
+  if (!lenParams.empty())
+    p << '(' << lenParams << " : " << lenParams.getTypes() << ')';
+
+  for (mlir::Value extent : getShape()) {
+    p << ", ";
+    p.printOperand(extent);
+  }
+
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          {"in_type", "operandSegmentSizes"});
+}
+
+/// Verify the operation: the allocated type must be valid for the number of
+/// shape operands, the LEN parameter count must correspond to the type, the
+/// result must be a `!fir.heap` type, and its element type must not be a
+/// `!fir.box` of unknown rank or type.
+llvm::LogicalResult fir::OmpTargetAllocMemOp::verify() {
+  llvm::SmallVector<llvm::StringRef> visited;
+  if (verifyInType(getInType(), visited, numShapeOperands()))
+    return emitOpError("invalid type for allocation");
+  if (verifyTypeParamCount(getInType(), numLenParams()))
+    return emitOpError("LEN params do not correspond to type");
+  mlir::Type outType = getType();
+  // Only a type test is needed here, so use isa instead of discarding the
+  // result of a dyn_cast.
+  if (!mlir::isa<fir::HeapType>(outType))
+    return emitOpError("must be a !fir.heap type");
+  if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType)))
+    return emitOpError("cannot allocate !fir.box of unknown rank or type");
+  return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// ArrayCoorOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/omp_target_allocmem.fir b/flang/test/Fir/omp_target_allocmem.fir
new file mode 100644
index 0000000000000..5140c91c9510c
--- /dev/null
+++ b/flang/test/Fir/omp_target_allocmem.fir
@@ -0,0 +1,28 @@
+// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s
+
+// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_nonchar(
+// CHECK: call ptr @omp_target_alloc(i64 36, i32 0)
+func.func @omp_target_allocmem_array_of_nonchar() -> !fir.heap<!fir.array<3x3xi32>> {
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
+ return %1 : !fir.heap<!fir.array<3x3xi32>>
+}
+
+// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_char(
+// CHECK: call ptr @omp_target_alloc(i64 90, i32 0)
+func.func @omp_target_allocmem_array_of_char() -> !fir.heap<!fir.array<3x3x!fir.char<1,10>>> {
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>>
+ return %1 : !fir.heap<!fir.array<3x3x!fir.char<1,10>>>
+}
+
+// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_dynchar(
+// CHECK-SAME: i32 %[[len:.*]])
+// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64
+// CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]]
+// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0)
+func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> !fir.heap<!fir.array<3x3x!fir.char<1,?>>> {
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
+ return %1 : !fir.heap<!fir.array<3x3x!fir.char<1,?>>>
+}
diff --git a/flang/test/Fir/omp_target_freemem.fir b/flang/test/Fir/omp_target_freemem.fir
new file mode 100644
index 0000000000000..02e136076a9cf
--- /dev/null
+++ b/flang/test/Fir/omp_target_freemem.fir
@@ -0,0 +1,28 @@
+// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar(
+// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
+func.func @omp_target_allocmem_array_of_nonchar() -> () {
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
+ fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3xi32>>
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_of_char(
+// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
+func.func @omp_target_allocmem_array_of_char() -> () {
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>>
+ fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3x!fir.char<1,10>>>
+ return
+}
+
+// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar(
+// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
+func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () {
+ %device = arith.constant 0 : i32
+ %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
+ fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3x!fir.char<1,?>>>
+ return
+}
More information about the flang-commits
mailing list