[flang-commits] [flang] [flang][CUDA] Unify element size computation in CUF helpers (PR #167398)
Zhen Wang via flang-commits
flang-commits at lists.llvm.org
Mon Nov 10 13:48:24 PST 2025
https://github.com/wangzpgi created https://github.com/llvm/llvm-project/pull/167398
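Refactor computeWidth from CUFOpConversion into a shared helper function, cuf::computeElementByteSize, in CUFCommon. The shared helper adds an optional emitErrorOnFailure parameter (default true) that controls whether the "unsupported type" error is emitted before returning 0, and it drops the isInteger(1) special case, which was unreachable because the preceding IntegerType check already returns for every integer type. Existing call sites keep their behavior.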
>From ab0c9cb6e23280900bf0c92e3bd434f9da49ed76 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 10 Nov 2025 13:15:09 -0800
Subject: [PATCH] [flang][CUDA] Unify element size computation in CUF helpers
Refactor computeWidth from CUFOpConversion into a shared helper
function computeElementByteSize in CUFCommon.
---
.../flang/Optimizer/Builder/CUFCommon.h | 5 +++
flang/lib/Optimizer/Builder/CUFCommon.cpp | 23 +++++++++++++
.../Optimizer/Transforms/CUFOpConversion.cpp | 32 +++----------------
3 files changed, 33 insertions(+), 27 deletions(-)
diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 5c56dd6b695f8..6e2442745f9a0 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
namespace fir {
class FirOpBuilder;
+class KindMapping;
} // namespace fir
namespace cuf {
@@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
+int computeElementByteSize(mlir::Location loc, mlir::Type type,
+ fir::KindMapping &kindMap,
+ bool emitErrorOnFailure = true);
+
} // namespace cuf
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index cf7588f275d22..461deb8e4cb55 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -9,6 +9,7 @@
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
}
}
}
+
+int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
+ fir::KindMapping &kindMap,
+ bool emitErrorOnFailure) {
+ auto eleTy = fir::unwrapSequenceType(type);
+ if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
+ return t.getWidth() / 8;
+ if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
+ return t.getWidth() / 8;
+ if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
+ return kindMap.getLogicalBitsize(t.getFKind()) / 8;
+ if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
+ int elemSize =
+ mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
+ return 2 * elemSize;
+ }
+ if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
+ return kindMap.getCharacterBitsize(t.getFKind()) / 8;
+ if (emitErrorOnFailure)
+ mlir::emitError(loc, "unsupported type");
+ return 0;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8d00272b09f42..3c3782cc234f8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
return false;
}
-static int computeWidth(mlir::Location loc, mlir::Type type,
- fir::KindMapping &kindMap) {
- auto eleTy = fir::unwrapSequenceType(type);
- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
- return t.getWidth() / 8;
- if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
- return t.getWidth() / 8;
- if (eleTy.isInteger(1))
- return 1;
- if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
- return kindMap.getLogicalBitsize(t.getFKind()) / 8;
- if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
- int elemSize =
- mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
- return 2 * elemSize;
- }
- if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
- return kindMap.getCharacterBitsize(t.getFKind()) / 8;
- mlir::emitError(loc, "unsupported type");
- return 0;
-}
-
struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;
@@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::Value bytes;
fir::KindMapping kindMap{fir::getKindMapping(mod)};
if (fir::isa_trivial(op.getInType())) {
- int width = computeWidth(loc, op.getInType(), kindMap);
+ int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
bytes =
builder.createIntegerConstant(loc, builder.getIndexType(), width);
} else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
size = dl->getTypeSizeInBits(structTy) / 8;
} else {
- size = computeWidth(loc, seqTy.getEleTy(), kindMap);
+ size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
}
mlir::Value width =
builder.createIntegerConstant(loc, builder.getIndexType(), size);
@@ -619,8 +597,8 @@ struct CUFDataTransferOpConversion
const mlir::SymbolTable &symtab,
mlir::DataLayout *dl,
const fir::LLVMTypeConverter *typeConverter)
- : OpRewritePattern(context), symtab{symtab}, dl{dl},
- typeConverter{typeConverter} {}
+ : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
+ typeConverter} {}
mlir::LogicalResult
matchAndRewrite(cuf::DataTransferOp op,
@@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
typeConverter->convertType(fir::unwrapSequenceType(dstTy));
width = dl->getTypeSizeInBits(structTy) / 8;
} else {
- width = computeWidth(loc, dstTy, kindMap);
+ width = cuf::computeElementByteSize(loc, dstTy, kindMap);
}
mlir::Value widthValue = mlir::arith::ConstantOp::create(
rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
More information about the flang-commits
mailing list