[Mlir-commits] [flang] [mlir] [flang] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (PR #145464)
Kareem Ergawy
llvmlistbot at llvm.org
Wed Jul 9 05:35:25 PDT 2025
================
@@ -125,10 +125,177 @@ struct PrivateClauseOpConversion
return mlir::success();
}
};
+
+/// Returns the `omp_target_alloc` LLVM function declaration from the module
+/// that encloses `op`. If the module has no `llvm.func` with that name yet,
+/// an external declaration with signature `ptr (i64, i32)` is inserted at the
+/// start of the module body and returned.
+///
+/// NOTE(review): the size parameter is always declared as i64, independent
+/// of the converter's index width — confirm callers always pass a 64-bit size.
+/// NOTE(review): the lookup only matches an existing `llvm.func`; a same-named
+/// symbol of another op kind would not be reused — confirm that cannot occur.
+static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
+  static constexpr char kRuntimeFnName[] = "omp_target_alloc";
+
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  if (auto existing =
+          module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(kRuntimeFnName))
+    return existing;
+
+  auto *ctx = module->getContext();
+  auto sizeTy = mlir::IntegerType::get(ctx, 64);
+  auto deviceNumTy = mlir::IntegerType::get(ctx, 32);
+  auto fnTy = mlir::LLVM::LLVMFunctionType::get(
+      mlir::LLVM::LLVMPointerType::get(ctx), {sizeTy, deviceNumTy},
+      /*isVarArg=*/false);
+
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      moduleBuilder.getUnknownLoc(), kRuntimeFnName, fnTy);
+}
+
+/// Lowers `firType` to the LLVM type of the object itself. Box types
+/// (`fir::BaseBoxType`) are lowered to their struct form through
+/// `convertBoxTypeAsStruct`; every other type uses the ordinary conversion.
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
+                                    mlir::Type firType) {
+  auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType);
+  if (!boxTy)
+    return converter.convertType(firType);
+  return converter.convertBoxTypeAsStruct(boxTy);
+}
+
+/// Builds a new attribute list from `attrs` for a call with
+/// `numCallOperands` operands:
+///  - every attribute except an existing `operandSegmentSizes` is copied over;
+///  - `operandSegmentSizes` is (re)set to `[numCallOperands, 0]`;
+///  - an empty `op_bundle_sizes` array is appended.
+/// NOTE(review): the trailing 0 segment presumably stands for "no operand
+/// bundle operands" — confirm against the target call op's definition.
+static llvm::SmallVector<mlir::NamedAttribute>
+addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
+                     llvm::ArrayRef<mlir::NamedAttribute> attrs,
+                     int32_t numCallOperands) {
+  llvm::SmallVector<mlir::NamedAttribute> result;
+  // Room for the copied attributes plus the two appended below.
+  result.reserve(attrs.size() + 2);
+
+  for (const mlir::NamedAttribute &namedAttr : attrs)
+    if (namedAttr.getName() != "operandSegmentSizes")
+      result.push_back(namedAttr);
+
+  mlir::NamedAttribute segmentSizes = rewriter.getNamedAttr(
+      "operandSegmentSizes",
+      rewriter.getDenseI32ArrayAttr({numCallOperands, 0}));
+  mlir::NamedAttribute bundleSizes = rewriter.getNamedAttr(
+      "op_bundle_sizes", rewriter.getDenseI32ArrayAttr({}));
+  result.push_back(segmentSizes);
+  result.push_back(bundleSizes);
+  return result;
+}
+
+/// Creates an `mlir::LLVM::ConstantOp` of result type `ity` holding `offset`.
+/// The value attribute is always built as a 64-bit integer attribute,
+/// whatever the width of `ity`.
+static mlir::LLVM::ConstantOp
+genConstantIndex(mlir::Location loc, mlir::Type ity,
+                 mlir::ConversionPatternRewriter &rewriter,
+                 std::int64_t offset) {
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, ity, rewriter.getI64IntegerAttr(offset));
+}
+
+/// Returns an `idxTy`-typed constant holding the byte distance between
+/// consecutive objects of `llvmObjectType`: the type's data-layout size
+/// rounded up to its ABI alignment.
+static mlir::Value
+computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
+                       mlir::Type idxTy,
+                       mlir::ConversionPatternRewriter &rewriter,
+                       const mlir::DataLayout &dataLayout) {
+  llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
+  // Keep the full 64-bit alignment reported by the data layout. Narrowing it
+  // to a 16-bit type would silently truncate large alignments (65536 would
+  // become 0, which also violates alignTo's non-zero-alignment precondition).
+  std::uint64_t alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
+  std::int64_t distance = llvm::alignTo(size, alignment);
+  return genConstantIndex(loc, idxTy, rewriter, distance);
+}
+
+/// Returns, as an `idxTy`-typed constant, the number of bytes that one object
+/// of LLVM type `llTy` occupies in an allocation. This is the element
+/// distance (size rounded up to ABI alignment) from `computeElementDistance`.
+static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
+                                      mlir::ConversionPatternRewriter &rewriter,
+                                      mlir::Type llTy,
+                                      const mlir::DataLayout &dataLayout) {
+  mlir::Value elementStride =
+      computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
+  return elementStride;
+}
+
+/// Computes an extra constant scaling factor for an allocation of
+/// `op.getInType()`.
+///
+/// When the in-type is a `fir::SequenceType` whose shape is not entirely made
+/// of leading constant rows, the result is the product of the known extents
+/// that follow those leading constant rows, materialized as an `ity` constant.
+/// A null value is returned when the in-type is not a sequence, when every
+/// dimension is a leading constant row, or when the product is 1.
+///
+/// NOTE(review): the leading constant rows are skipped presumably because the
+/// converted element type's size already covers them — confirm against the
+/// FIR-to-LLVM type converter.
+template <typename OP>
+static mlir::Value
+genAllocationScaleSize(OP op, mlir::Type ity,
+                       mlir::ConversionPatternRewriter &rewriter) {
+  auto seqTy = mlir::dyn_cast<fir::SequenceType>(op.getInType());
+  if (!seqTy)
+    return nullptr;
+
+  const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
+  const int leadingConstRows = seqTy.getConstantRows();
+  if (leadingConstRows == static_cast<int>(shape.size()))
+    return nullptr;
+
+  fir::SequenceType::Extent product = 1;
+  for (unsigned dim = leadingConstRows; dim < shape.size(); ++dim) {
+    const fir::SequenceType::Extent extent = shape[dim];
+    // Unknown (dynamic) extents contribute nothing to the constant factor.
+    if (extent != fir::SequenceType::getUnknownExtent())
+      product *= extent;
+  }
+
+  if (product == 1)
+    return nullptr;
+  return genConstantIndex(op.getLoc(), ity, rewriter, product).getResult();
+}
+
+/// Converts the integer value `val` to the integer type `ty`:
+///  - truncates (`LLVM::TruncOp`) when `ty` is narrower,
+///  - sign-extends (`LLVM::SExtOp`) when `ty` is wider,
+///  - returns `val` unchanged when the bit widths are equal.
+/// If `val`'s type is not yet an `mlir::IntegerType` (the value was not
+/// lowered yet), that type is first converted with `converter`, only so its
+/// bit width can be measured. With `fold` set, the cast is built through
+/// `createOrFold`, so a folded value may be returned instead of a new op.
+static mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
+                               mlir::Location loc,
+                               mlir::ConversionPatternRewriter &rewriter,
+                               mlir::Type ty, mlir::Value val,
+                               bool fold = false) {
+  mlir::Type srcTy = val.getType();
+  if (!mlir::isa<mlir::IntegerType>(srcTy))
+    srcTy = converter.convertType(srcTy);
+
+  auto dstBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+  auto srcBits = mlir::LLVM::getPrimitiveTypeSizeInBits(srcTy);
+  const bool narrowing = dstBits < srcBits;
+  const bool widening = dstBits > srcBits;
+
+  if (!narrowing && !widening)
+    return val;
+
+  if (narrowing) {
+    if (fold)
+      return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
+    return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+  }
+
+  if (fold)
+    return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
+  return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
+}
+
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
+ : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
+ using OpenMPFIROpConversion::OpenMPFIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::Type heapTy = allocmemOp.getAllocatedType();
+ mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp);
+ mlir::Location loc = allocmemOp.getLoc();
+ auto ity = lowerTy().indexType();
+ mlir::Type dataTy = fir::unwrapRefType(heapTy);
+ mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+ mlir::Type llvmPtrTy =
+ mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0);
+ if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
+ TODO(loc, "omp.target_allocmem codegen of derived type with length "
+ "parameters");
+ mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy,
+ lowerTy().getDataLayout());
+ if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter))
+ size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+ for (mlir::Value opnd : adaptor.getOperands().drop_front())
+ size = rewriter.create<mlir::LLVM::MulOp>(
+ loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
+ auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+ auto mallocTy =
+ mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
+ if (mallocTyWidth != ity.getIntOrFloatBitWidth())
+ size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
----------------
ergawy wrote:
I think this `if` condition is not covered by the introduced tests.
https://github.com/llvm/llvm-project/pull/145464
More information about the Mlir-commits mailing list