[Mlir-commits] [flang] [mlir] [flang] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (PR #145464)
Kareem Ergawy
llvmlistbot at llvm.org
Mon Jul 14 05:28:26 PDT 2025
================
@@ -125,10 +125,177 @@ struct PrivateClauseOpConversion
return mlir::success();
}
};
+
+static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
+ auto module = op->getParentOfType<mlir::ModuleOp>();
+ if (mlir::LLVM::LLVMFuncOp mallocFunc =
+ module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
+ return mallocFunc;
+ mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+ auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
+ auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
+ return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+ moduleBuilder.getUnknownLoc(), "omp_target_alloc",
+ mlir::LLVM::LLVMFunctionType::get(
+ mlir::LLVM::LLVMPointerType::get(module->getContext()),
+ {i64Ty, i32Ty},
+ /*isVarArg=*/false));
+}
+
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
+ mlir::Type firType) {
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
+ return converter.convertBoxTypeAsStruct(boxTy);
+ return converter.convertType(firType);
+}
+
+static llvm::SmallVector<mlir::NamedAttribute>
+addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
+ llvm::ArrayRef<mlir::NamedAttribute> attrs,
+ int32_t numCallOperands) {
+ llvm::SmallVector<mlir::NamedAttribute> newAttrs;
+ newAttrs.reserve(attrs.size() + 2);
+
+ for (mlir::NamedAttribute attr : attrs) {
+ if (attr.getName() != "operandSegmentSizes")
+ newAttrs.push_back(attr);
+ }
+
+ newAttrs.push_back(rewriter.getNamedAttr(
+ "operandSegmentSizes",
+ rewriter.getDenseI32ArrayAttr({numCallOperands, 0})));
+ newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
+ rewriter.getDenseI32ArrayAttr({})));
+ return newAttrs;
+}
+
+static mlir::LLVM::ConstantOp
+genConstantIndex(mlir::Location loc, mlir::Type ity,
+ mlir::ConversionPatternRewriter &rewriter,
+ std::int64_t offset) {
+ auto cattr = rewriter.getI64IntegerAttr(offset);
+ return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
+}
+
+static mlir::Value
+computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
+ mlir::Type idxTy,
+ mlir::ConversionPatternRewriter &rewriter,
+ const mlir::DataLayout &dataLayout) {
+ llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
+ unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
+ std::int64_t distance = llvm::alignTo(size, alignment);
+ return genConstantIndex(loc, idxTy, rewriter, distance);
+}
+
+static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::Type llTy,
+ const mlir::DataLayout &dataLayout) {
+ return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
+}
+
+template <typename OP>
+static mlir::Value
+genAllocationScaleSize(OP op, mlir::Type ity,
+ mlir::ConversionPatternRewriter &rewriter) {
+ mlir::Location loc = op.getLoc();
+ mlir::Type dataTy = op.getInType();
+ auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
+ fir::SequenceType::Extent constSize = 1;
+ if (seqTy) {
+ int constRows = seqTy.getConstantRows();
+ const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
+ if (constRows != static_cast<int>(shape.size())) {
+ for (auto extent : shape) {
+ if (constRows-- > 0)
+ continue;
+ if (extent != fir::SequenceType::getUnknownExtent())
+ constSize *= extent;
+ }
+ }
+ }
+
+ if (constSize != 1) {
+ mlir::Value constVal{
+ genConstantIndex(loc, ity, rewriter, constSize).getResult()};
+ return constVal;
+ }
+ return nullptr;
+}
+
+static mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
+ mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::Type ty, mlir::Value val,
+ bool fold = false) {
+ auto valTy = val.getType();
+ // If the value was not yet lowered, lower its type so that it can
+ // be used in getPrimitiveTypeSizeInBits.
+ if (!mlir::isa<mlir::IntegerType>(valTy))
+ valTy = converter.convertType(valTy);
+ auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+ auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
+ if (fold) {
+ if (toSize < fromSize)
+ return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
+ if (toSize > fromSize)
+ return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
+ } else {
+ if (toSize < fromSize)
+ return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+ if (toSize > fromSize)
+ return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
+ }
+ return val;
+}
+
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
----------------
ergawy wrote:
Thanks for looking into this, Chaitanya. Here is a proposal for how we can do that:
```diff
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 14cc7bb511f0..36f0265e8995 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -126,22 +126,6 @@ struct PrivateClauseOpConversion
}
};
-static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
- auto module = op->getParentOfType<mlir::ModuleOp>();
- if (mlir::LLVM::LLVMFuncOp mallocFunc =
- module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
- return mallocFunc;
- mlir::OpBuilder moduleBuilder(module.getBodyRegion());
- auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
- auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
- return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
- moduleBuilder.getUnknownLoc(), "omp_target_alloc",
- mlir::LLVM::LLVMFunctionType::get(
- mlir::LLVM::LLVMPointerType::get(module->getContext()),
- {i64Ty, i32Ty},
- /*isVarArg=*/false));
-}
-
static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
mlir::Type firType) {
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
@@ -259,13 +243,11 @@ struct TargetAllocMemOpConversion
matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Type heapTy = allocmemOp.getAllocatedType();
- mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp);
mlir::Location loc = allocmemOp.getLoc();
auto ity = lowerTy().indexType();
mlir::Type dataTy = fir::unwrapRefType(heapTy);
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
- mlir::Type llvmPtrTy =
- mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0);
+
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
TODO(loc, "omp.target_allocmem codegen of derived type with length "
"parameters");
@@ -273,6 +255,7 @@ struct TargetAllocMemOpConversion
lowerTy().getDataLayout());
if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter))
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+
for (mlir::Value opnd : adaptor.getOperands().drop_front())
size = rewriter.create<mlir::LLVM::MulOp>(
loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
@@ -281,13 +264,13 @@ struct TargetAllocMemOpConversion
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
- allocmemOp->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
- auto callOp = rewriter.create<mlir::LLVM::CallOp>(
- loc, llvmPtrTy,
- mlir::SmallVector<mlir::Value, 2>({size, allocmemOp.getDevice()}),
- addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2));
- rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(
- allocmemOp, rewriter.getIntegerType(64), callOp.getResult());
+
+ rewriter.modifyOpInPlace(allocmemOp, [&]() {
+ allocmemOp.setInType(rewriter.getI8Type());
+ allocmemOp.getTypeparamsMutable().clear();
+ allocmemOp.getTypeparamsMutable().append(size);
+ });
+
return mlir::success();
}
};
diff --git a/flang/test/Fir/omp_target_allocmem_freemem.fir b/flang/test/Fir/omp_target_allocmem_freemem.fir
index 920220272845..00a4462510a9 100644
--- a/flang/test/Fir/omp_target_allocmem_freemem.fir
+++ b/flang/test/Fir/omp_target_allocmem_freemem.fir
@@ -24,7 +24,8 @@ func.func @omp_target_allocmem_array_of_char() -> () {
// CHECK-SAME: i32 %[[len:.*]])
// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64
// CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]]
-// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0)
+// CHECK: %[[mul3:.*]] = mul i64 1, %[[mul2]]
+// CHECK: call ptr @omp_target_alloc(i64 %[[mul3]], i32 0)
func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () {
%device = arith.constant 0 : i32
%1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index c246d4abbdfe..0744bac98b3f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5923,7 +5923,12 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
mlir::Type heapTy = allocMemOp.getAllocatedType();
llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
- llvm::ConstantInt *allocSize = builder.getInt64(typeSize.getFixedValue());
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+
+ for (auto typeParam : allocMemOp.getTypeparams())
+ allocSize =
+ builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+
// Create call to "omp_target_alloc" with the args as translated llvm values.
llvm::CallInst *call =
builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
```
With that, after converting `fir`, we always have an `omp.target_allocmem` op that allocates the correct number of bytes based on the converted `fir` type.
So when `TargetAllocMemOpConversion` is run, you will get something like this:
```mlir
llvm.func @omp_target_allocmem_array_of_dynchar(%arg0: i32) {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.mlir.constant(1 : i64) : i64
%2 = llvm.mlir.constant(9 : i64) : i64
%3 = llvm.mul %1, %2 : i64
%4 = llvm.sext %arg0 : i32 to i64
%5 = llvm.mul %3, %4 : i64
%6 = omp.target_allocmem %0 : i32, i8(%5 : i64)
omp.target_freemem %0, %6 : i32, i64
llvm.return
}
```
(note that this is the conversion of the 3rd test you added in `omp_target_allocmem_freemem.fir`). You can reproduce the same result using: `./bin/fir-opt --cfg-conversion --fir-to-llvm-ir="target=aarch64-unknown-linux-gnu"` which will test only this part of the pipeline. Then you can use `mlir-translate` to test the final OpenMP to LLVM conversion.
https://github.com/llvm/llvm-project/pull/145464
More information about the Mlir-commits
mailing list