[flang-commits] [flang] [flang][cuda] Allocate descriptor in managed memory on rebox block argument (PR #123971)
via flang-commits
flang-commits at lists.llvm.org
Wed Jan 22 09:10:28 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
This patch handles another case where the descriptor must be allocated through the CUF runtime rather than with a plain `alloca` instruction: a rebox whose input box is a block argument.
---
Full diff: https://github.com/llvm/llvm-project/pull/123971.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+20-18)
- (modified) flang/test/Fir/CUDA/cuda-code-gen.mlir (+11)
``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 43c0e2686a8c3b..6ff2c20d744537 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2040,19 +2040,20 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter);
if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
- return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
- operands, rewriter);
- return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
- operands, rewriter);
+ return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
+ inputStrides, operands, rewriter);
+ return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
+ inputStrides, operands, rewriter);
}
private:
/// Write resulting shape and base address in descriptor, and replace rebox
/// op.
llvm::LogicalResult
- finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
- mlir::Value base, mlir::ValueRange lbounds,
- mlir::ValueRange extents, mlir::ValueRange strides,
+ finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
+ mlir::Type destBoxTy, mlir::Value dest, mlir::Value base,
+ mlir::ValueRange lbounds, mlir::ValueRange extents,
+ mlir::ValueRange strides,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
mlir::Value zero =
@@ -2075,15 +2076,15 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
dest = insertBaseAddress(rewriter, loc, dest, base);
mlir::Value result = placeInMemoryIfNotGlobalInit(
rewriter, rebox.getLoc(), destBoxTy, dest,
- isDeviceAllocation(rebox.getBox(), rebox.getBox()));
+ isDeviceAllocation(rebox.getBox(), adaptor.getBox()));
rewriter.replaceOp(rebox, result);
return mlir::success();
}
// Apply slice given the base address, extents and strides of the input box.
llvm::LogicalResult
- sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
- mlir::Value base, mlir::ValueRange inputExtents,
+ sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
+ mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
mlir::ValueRange inputStrides, mlir::ValueRange operands,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
@@ -2109,7 +2110,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
if (rebox.getSlice().empty())
// The array section is of the form array[%component][substring], keep
// the input array extents and strides.
- return finalizeRebox(rebox, destBoxTy, dest, base,
+ return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
/*lbounds*/ std::nullopt, inputExtents, inputStrides,
rewriter);
@@ -2158,15 +2159,16 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
slicedStrides.emplace_back(stride);
}
}
- return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt,
- slicedExtents, slicedStrides, rewriter);
+ return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
+ /*lbounds*/ std::nullopt, slicedExtents, slicedStrides,
+ rewriter);
}
/// Apply a new shape to the data described by a box given the base address,
/// extents and strides of the box.
llvm::LogicalResult
- reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
- mlir::Value base, mlir::ValueRange inputExtents,
+ reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
+ mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
mlir::ValueRange inputStrides, mlir::ValueRange operands,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::ValueRange reboxShifts{
@@ -2175,7 +2177,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
rebox.getShift().size()};
if (rebox.getShape().empty()) {
// Only setting new lower bounds.
- return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts,
+ return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
inputExtents, inputStrides, rewriter);
}
@@ -2199,8 +2201,8 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
// nextStride = extent * stride;
stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
}
- return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents,
- newStrides, rewriter);
+ return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
+ newExtents, newStrides, rewriter);
}
/// Return scalar element type of the input box.
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index 7ac89836a3ff16..063454799502af 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -187,3 +187,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
// CHECK-LABEL: llvm.func @_QPouter
// CHECK: _FortranACUFAllocDescriptor
+
+// -----
+
+func.func @_QMm1Psub1(%arg0: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "da"}, %arg1: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "db"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}) {
+ %0 = fircg.ext_rebox %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+ %1 = fircg.ext_rebox %arg1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+ return
+}
+
+// CHECK-LABEL: llvm.func @_QMm1Psub1
+// CHECK-COUNT-2: _FortranACUFAllocDescriptor
``````````
</details>
https://github.com/llvm/llvm-project/pull/123971
More information about the flang-commits mailing list