[Mlir-commits] [flang] [mlir] [flang] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (PR #145464)

Kareem Ergawy llvmlistbot at llvm.org
Wed Jul 9 05:35:25 PDT 2025


================
@@ -125,10 +125,177 @@ struct PrivateClauseOpConversion
     return mlir::success();
   }
 };
+
+static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) { // Find or declare "omp_target_alloc" in op's parent module.
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  if (mlir::LLVM::LLVMFuncOp mallocFunc =
+          module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
+    return mallocFunc; // A declaration already exists in the module: reuse it.
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion()); // Otherwise create the declaration in the module body.
+  auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
+  auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
+  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      moduleBuilder.getUnknownLoc(), "omp_target_alloc",
+      mlir::LLVM::LLVMFunctionType::get(
+          mlir::LLVM::LLVMPointerType::get(module->getContext()), // Result type: LLVM pointer.
+          {i64Ty, i32Ty}, // Parameters: (i64, i32); presumably (size, device number) as in the OpenMP runtime API -- confirm.
+          /*isVarArg=*/false));
+}
+
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
+                                    mlir::Type firType) { // Convert firType with the given converter.
+  if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
+    return converter.convertBoxTypeAsStruct(boxTy); // Box types go through convertBoxTypeAsStruct.
+  return converter.convertType(firType); // Every other type uses the generic conversion.
+}
+
+static llvm::SmallVector<mlir::NamedAttribute>
+addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
+                     llvm::ArrayRef<mlir::NamedAttribute> attrs,
+                     int32_t numCallOperands) { // Returns a copy of attrs with fresh segment-size and bundle-size attributes.
+  llvm::SmallVector<mlir::NamedAttribute> newAttrs;
+  newAttrs.reserve(attrs.size() + 2); // Room for the two attributes appended below.
+
+  for (mlir::NamedAttribute attr : attrs) {
+    if (attr.getName() != "operandSegmentSizes") // Drop any incoming "operandSegmentSizes"; it is rebuilt below.
+      newAttrs.push_back(attr);
+  }
+
+  newAttrs.push_back(rewriter.getNamedAttr(
+      "operandSegmentSizes",
+      rewriter.getDenseI32ArrayAttr({numCallOperands, 0}))); // Two segments: numCallOperands and 0.
+  newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
+                                           rewriter.getDenseI32ArrayAttr({}))); // Empty array.
+  return newAttrs;
+}
+
+static mlir::LLVM::ConstantOp
+genConstantIndex(mlir::Location loc, mlir::Type ity,
+                 mlir::ConversionPatternRewriter &rewriter,
+                 std::int64_t offset) { // Create an LLVM constant op of type ity holding offset.
+  auto cattr = rewriter.getI64IntegerAttr(offset); // The value is carried as a 64-bit integer attribute.
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
+}
+
+static mlir::Value
+computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
+                       mlir::Type idxTy,
+                       mlir::ConversionPatternRewriter &rewriter,
+                       const mlir::DataLayout &dataLayout) { // Constant of type idxTy: the type's size rounded up to its ABI alignment.
+  llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
+  unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
+  std::int64_t distance = llvm::alignTo(size, alignment); // Round size up to a multiple of alignment.
+  return genConstantIndex(loc, idxTy, rewriter, distance);
+}
+
+static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
+                                      mlir::ConversionPatternRewriter &rewriter,
+                                      mlir::Type llTy,
+                                      const mlir::DataLayout &dataLayout) { // Thin wrapper with a different parameter order.
+  return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); // Forwards everything to computeElementDistance.
+}
+
+template <typename OP>
+static mlir::Value
+genAllocationScaleSize(OP op, mlir::Type ity,
+                       mlir::ConversionPatternRewriter &rewriter) { // Constant scale factor derived from op.getInType(), or a null Value when that factor is 1.
+  mlir::Location loc = op.getLoc();
+  mlir::Type dataTy = op.getInType();
+  auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
+  fir::SequenceType::Extent constSize = 1; // Running product of extents; stays 1 for non-sequence types.
+  if (seqTy) {
+    int constRows = seqTy.getConstantRows();
+    const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
+    if (constRows != static_cast<int>(shape.size())) { // Only when constRows does not cover every dimension.
+      for (auto extent : shape) {
+        if (constRows-- > 0) // Skip the first constRows extents.
+          continue;
+        if (extent != fir::SequenceType::getUnknownExtent()) // Unknown extents do not contribute to the product.
+          constSize *= extent;
+      }
+    }
+  }
+
+  if (constSize != 1) { // Materialize a constant only for a non-trivial factor.
+    mlir::Value constVal{
+        genConstantIndex(loc, ity, rewriter, constSize).getResult()};
+    return constVal;
+  }
+  return nullptr; // Null Value: the caller has nothing to multiply by.
+}
+
+static mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
+                               mlir::Location loc,
+                               mlir::ConversionPatternRewriter &rewriter,
+                               mlir::Type ty, mlir::Value val,
+                               bool fold = false) { // Truncate or sign-extend val to ty; returns val itself when the widths match.
+  auto valTy = val.getType();
+  // If the value was not yet lowered, lower its type so that it can
+  // be used in getPrimitiveTypeSizeInBits.
+  if (!mlir::isa<mlir::IntegerType>(valTy))
+    valTy = converter.convertType(valTy);
+  auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+  auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
+  if (fold) { // fold == true: use the createOrFold builder entry point.
+    if (toSize < fromSize)
+      return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
+    if (toSize > fromSize)
+      return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val); // Widening is a sign extension.
+  } else { // fold == false: use the plain create builder entry point.
+    if (toSize < fromSize)
+      return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+    if (toSize > fromSize)
+      return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); // Widening is a sign extension.
+  }
+  return val; // Same bit width: no cast op is created.
+}
+
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
+    : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
+  using OpenMPFIROpConversion::OpenMPFIROpConversion;
+
+  llvm::LogicalResult
+  matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Type heapTy = allocmemOp.getAllocatedType();
+    mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp);
+    mlir::Location loc = allocmemOp.getLoc();
+    auto ity = lowerTy().indexType();
+    mlir::Type dataTy = fir::unwrapRefType(heapTy);
+    mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+    mlir::Type llvmPtrTy =
+        mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0);
+    if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
+      TODO(loc, "omp.target_allocmem codegen of derived type with length "
+                "parameters");
+    mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy,
+                                          lowerTy().getDataLayout());
+    if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter))
+      size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+    for (mlir::Value opnd : adaptor.getOperands().drop_front())
+      size = rewriter.create<mlir::LLVM::MulOp>(
+          loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
+    auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+    auto mallocTy =
+        mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
+    if (mallocTyWidth != ity.getIntOrFloatBitWidth())
+      size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
----------------
ergawy wrote:

I think this `if` condition is not covered by the introduced tests.

https://github.com/llvm/llvm-project/pull/145464


More information about the Mlir-commits mailing list