[Mlir-commits] [flang] [mlir] [flang] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (PR #145464)

Kareem Ergawy llvmlistbot at llvm.org
Mon Jul 14 05:28:26 PDT 2025


================
@@ -125,10 +125,177 @@ struct PrivateClauseOpConversion
     return mlir::success();
   }
 };
+
+static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  if (mlir::LLVM::LLVMFuncOp mallocFunc =
+          module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
+    return mallocFunc;
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
+  auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
+  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      moduleBuilder.getUnknownLoc(), "omp_target_alloc",
+      mlir::LLVM::LLVMFunctionType::get(
+          mlir::LLVM::LLVMPointerType::get(module->getContext()),
+          {i64Ty, i32Ty},
+          /*isVarArg=*/false));
+}
+
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
+                                    mlir::Type firType) {
+  if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
+    return converter.convertBoxTypeAsStruct(boxTy);
+  return converter.convertType(firType);
+}
+
+static llvm::SmallVector<mlir::NamedAttribute>
+addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
+                     llvm::ArrayRef<mlir::NamedAttribute> attrs,
+                     int32_t numCallOperands) {
+  llvm::SmallVector<mlir::NamedAttribute> newAttrs;
+  newAttrs.reserve(attrs.size() + 2);
+
+  for (mlir::NamedAttribute attr : attrs) {
+    if (attr.getName() != "operandSegmentSizes")
+      newAttrs.push_back(attr);
+  }
+
+  newAttrs.push_back(rewriter.getNamedAttr(
+      "operandSegmentSizes",
+      rewriter.getDenseI32ArrayAttr({numCallOperands, 0})));
+  newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
+                                           rewriter.getDenseI32ArrayAttr({})));
+  return newAttrs;
+}
+
+static mlir::LLVM::ConstantOp
+genConstantIndex(mlir::Location loc, mlir::Type ity,
+                 mlir::ConversionPatternRewriter &rewriter,
+                 std::int64_t offset) {
+  auto cattr = rewriter.getI64IntegerAttr(offset);
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
+}
+
+static mlir::Value
+computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
+                       mlir::Type idxTy,
+                       mlir::ConversionPatternRewriter &rewriter,
+                       const mlir::DataLayout &dataLayout) {
+  llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
+  unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
+  std::int64_t distance = llvm::alignTo(size, alignment);
+  return genConstantIndex(loc, idxTy, rewriter, distance);
+}
+
+static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
+                                      mlir::ConversionPatternRewriter &rewriter,
+                                      mlir::Type llTy,
+                                      const mlir::DataLayout &dataLayout) {
+  return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
+}
+
+template <typename OP>
+static mlir::Value
+genAllocationScaleSize(OP op, mlir::Type ity,
+                       mlir::ConversionPatternRewriter &rewriter) {
+  mlir::Location loc = op.getLoc();
+  mlir::Type dataTy = op.getInType();
+  auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
+  fir::SequenceType::Extent constSize = 1;
+  if (seqTy) {
+    int constRows = seqTy.getConstantRows();
+    const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
+    if (constRows != static_cast<int>(shape.size())) {
+      for (auto extent : shape) {
+        if (constRows-- > 0)
+          continue;
+        if (extent != fir::SequenceType::getUnknownExtent())
+          constSize *= extent;
+      }
+    }
+  }
+
+  if (constSize != 1) {
+    mlir::Value constVal{
+        genConstantIndex(loc, ity, rewriter, constSize).getResult()};
+    return constVal;
+  }
+  return nullptr;
+}
+
+static mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
+                               mlir::Location loc,
+                               mlir::ConversionPatternRewriter &rewriter,
+                               mlir::Type ty, mlir::Value val,
+                               bool fold = false) {
+  auto valTy = val.getType();
+  // If the value was not yet lowered, lower its type so that it can
+  // be used in getPrimitiveTypeSizeInBits.
+  if (!mlir::isa<mlir::IntegerType>(valTy))
+    valTy = converter.convertType(valTy);
+  auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+  auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
+  if (fold) {
+    if (toSize < fromSize)
+      return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
+    if (toSize > fromSize)
+      return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
+  } else {
+    if (toSize < fromSize)
+      return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+    if (toSize > fromSize)
+      return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
+  }
+  return val;
+}
+
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
----------------
ergawy wrote:

Thanks for looking into this, Chaitanya. Here is a proposal on how we can do that:
```diff
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 14cc7bb511f0..36f0265e8995 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -126,22 +126,6 @@ struct PrivateClauseOpConversion
   }
 };
 
-static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
-  auto module = op->getParentOfType<mlir::ModuleOp>();
-  if (mlir::LLVM::LLVMFuncOp mallocFunc =
-          module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
-    return mallocFunc;
-  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
-  auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
-  auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
-  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
-      moduleBuilder.getUnknownLoc(), "omp_target_alloc",
-      mlir::LLVM::LLVMFunctionType::get(
-          mlir::LLVM::LLVMPointerType::get(module->getContext()),
-          {i64Ty, i32Ty},
-          /*isVarArg=*/false));
-}
-
 static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
                                     mlir::Type firType) {
   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
@@ -259,13 +243,11 @@ struct TargetAllocMemOpConversion
   matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Type heapTy = allocmemOp.getAllocatedType();
-    mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp);
     mlir::Location loc = allocmemOp.getLoc();
     auto ity = lowerTy().indexType();
     mlir::Type dataTy = fir::unwrapRefType(heapTy);
     mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
-    mlir::Type llvmPtrTy =
-        mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0);
+
     if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
       TODO(loc, "omp.target_allocmem codegen of derived type with length "
                 "parameters");
@@ -273,6 +255,7 @@ struct TargetAllocMemOpConversion
                                           lowerTy().getDataLayout());
     if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter))
       size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+
     for (mlir::Value opnd : adaptor.getOperands().drop_front())
       size = rewriter.create<mlir::LLVM::MulOp>(
           loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
@@ -281,13 +264,13 @@ struct TargetAllocMemOpConversion
         mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
     if (mallocTyWidth != ity.getIntOrFloatBitWidth())
       size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
-    allocmemOp->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
-    auto callOp = rewriter.create<mlir::LLVM::CallOp>(
-        loc, llvmPtrTy,
-        mlir::SmallVector<mlir::Value, 2>({size, allocmemOp.getDevice()}),
-        addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2));
-    rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(
-        allocmemOp, rewriter.getIntegerType(64), callOp.getResult());
+
+    rewriter.modifyOpInPlace(allocmemOp, [&]() {
+      allocmemOp.setInType(rewriter.getI8Type());
+      allocmemOp.getTypeparamsMutable().clear();
+      allocmemOp.getTypeparamsMutable().append(size);
+    });
+
     return mlir::success();
   }
 };
diff --git a/flang/test/Fir/omp_target_allocmem_freemem.fir b/flang/test/Fir/omp_target_allocmem_freemem.fir
index 920220272845..00a4462510a9 100644
--- a/flang/test/Fir/omp_target_allocmem_freemem.fir
+++ b/flang/test/Fir/omp_target_allocmem_freemem.fir
@@ -24,7 +24,8 @@ func.func @omp_target_allocmem_array_of_char() -> () {
 // CHECK-SAME: i32 %[[len:.*]])
 // CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64
 // CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]]
-// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0)
+// CHECK: %[[mul3:.*]] = mul i64 1, %[[mul2]]
+// CHECK: call ptr @omp_target_alloc(i64 %[[mul3]], i32 0)
 func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () {
   %device = arith.constant 0 : i32
   %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index c246d4abbdfe..0744bac98b3f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5923,7 +5923,12 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   mlir::Type heapTy = allocMemOp.getAllocatedType();
   llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
   llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
-  llvm::ConstantInt *allocSize = builder.getInt64(typeSize.getFixedValue());
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+
+  for (auto typeParam : allocMemOp.getTypeparams())
+    allocSize =
+        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+
   // Create call to "omp_target_alloc" with the args as translated llvm values.
   llvm::CallInst *call =
       builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
```
With that, after converting `fir`, we always have an `omp.target_allocmem` op that allocates the correct number of bytes based on the converted `fir` type.

So when `TargetAllocMemOpConversion` is run, you will get something like this:
```mlir
  llvm.func @omp_target_allocmem_array_of_dynchar(%arg0: i32) {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(1 : i64) : i64
    %2 = llvm.mlir.constant(9 : i64) : i64
    %3 = llvm.mul %1, %2 : i64
    %4 = llvm.sext %arg0 : i32 to i64
    %5 = llvm.mul %3, %4 : i64
    %6 = omp.target_allocmem %0 : i32, i8(%5 : i64)
    omp.target_freemem %0, %6 : i32, i64
    llvm.return
  }
```
(Note that this is the conversion of the 3rd test you added in `omp_target_allocmem_freemem.fir`.) You can reproduce the same result using `./bin/fir-opt --cfg-conversion --fir-to-llvm-ir="target=aarch64-unknown-linux-gnu"`, which will test only this part of the pipeline. Then you can use `mlir-translate` to test the final OpenMP to LLVM conversion.

https://github.com/llvm/llvm-project/pull/145464


More information about the Mlir-commits mailing list