[flang-commits] [flang] [flang][CUDA] Unify element size computation in CUF helpers (PR #167398)

Zhen Wang via flang-commits flang-commits at lists.llvm.org
Mon Nov 10 13:48:24 PST 2025


https://github.com/wangzpgi created https://github.com/llvm/llvm-project/pull/167398

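Move the file-local helper `computeWidth` out of CUFOpConversion.cpp into CUFCommon as the shared function `cuf::computeElementByteSize`, so other CUF code can reuse it. The shared version takes an additional `emitErrorOnFailure` argument (default `true`) that lets callers suppress the "unsupported type" diagnostic; existing call sites keep their behavior.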

>From ab0c9cb6e23280900bf0c92e3bd434f9da49ed76 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 10 Nov 2025 13:15:09 -0800
Subject: [PATCH] [flang][CUDA] Unify element size computation in CUF helpers

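Move the file-local computeWidth helper out of CUFOpConversion.cpp into
CUFCommon as the shared function cuf::computeElementByteSize. The shared
version adds an emitErrorOnFailure argument (default: true) so callers
can suppress the "unsupported type" diagnostic.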
---
 .../flang/Optimizer/Builder/CUFCommon.h       |  5 +++
 flang/lib/Optimizer/Builder/CUFCommon.cpp     | 23 +++++++++++++
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 32 +++----------------
 3 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 5c56dd6b695f8..6e2442745f9a0 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
 
 namespace fir {
 class FirOpBuilder;
+class KindMapping;
 } // namespace fir
 
 namespace cuf {
@@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
 
 void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
 
+int computeElementByteSize(mlir::Location loc, mlir::Type type,
+                           fir::KindMapping &kindMap,
+                           bool emitErrorOnFailure = true);
+
 } // namespace cuf
 
 #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index cf7588f275d22..461deb8e4cb55 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -9,6 +9,7 @@
 #include "flang/Optimizer/Builder/CUFCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
     }
   }
 }
+
+int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
+                                fir::KindMapping &kindMap,
+                                bool emitErrorOnFailure) {
+  // Round bit widths up so sub-byte types such as i1 still occupy one byte.
+  auto toBytes = [](unsigned bits) -> int { return (bits + 7) / 8; };
+  auto eleTy = fir::unwrapSequenceType(type);
+  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
+    return toBytes(t.getWidth());
+  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
+    return toBytes(t.getWidth());
+  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
+    return toBytes(kindMap.getLogicalBitsize(t.getFKind()));
+  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) // real + imaginary
+    return 2 * toBytes(
+                   mlir::cast<mlir::FloatType>(t.getElementType()).getWidth());
+  if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
+    return toBytes(kindMap.getCharacterBitsize(t.getFKind()));
+  if (emitErrorOnFailure)
+    mlir::emitError(loc, "unsupported type");
+  return 0; // 0 signals an unsupported element type.
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8d00272b09f42..3c3782cc234f8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
   return false;
 }
 
-static int computeWidth(mlir::Location loc, mlir::Type type,
-                        fir::KindMapping &kindMap) {
-  auto eleTy = fir::unwrapSequenceType(type);
-  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
-    return t.getWidth() / 8;
-  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
-    return t.getWidth() / 8;
-  if (eleTy.isInteger(1))
-    return 1;
-  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
-    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
-  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
-    int elemSize =
-        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
-    return 2 * elemSize;
-  }
-  if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
-    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
-  mlir::emitError(loc, "unsupported type");
-  return 0;
-}
-
 struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
       mlir::Value bytes;
       fir::KindMapping kindMap{fir::getKindMapping(mod)};
       if (fir::isa_trivial(op.getInType())) {
-        int width = computeWidth(loc, op.getInType(), kindMap);
+        int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
         bytes =
             builder.createIntegerConstant(loc, builder.getIndexType(), width);
       } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
           mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
           size = dl->getTypeSizeInBits(structTy) / 8;
         } else {
-          size = computeWidth(loc, seqTy.getEleTy(), kindMap);
+          size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
         }
         mlir::Value width =
             builder.createIntegerConstant(loc, builder.getIndexType(), size);
@@ -619,8 +597,8 @@ struct CUFDataTransferOpConversion
                               const mlir::SymbolTable &symtab,
                               mlir::DataLayout *dl,
                               const fir::LLVMTypeConverter *typeConverter)
-      : OpRewritePattern(context), symtab{symtab}, dl{dl},
-        typeConverter{typeConverter} {}
+      : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
+                                                               typeConverter} {}
 
   mlir::LogicalResult
   matchAndRewrite(cuf::DataTransferOp op,
@@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
             typeConverter->convertType(fir::unwrapSequenceType(dstTy));
         width = dl->getTypeSizeInBits(structTy) / 8;
       } else {
-        width = computeWidth(loc, dstTy, kindMap);
+        width = cuf::computeElementByteSize(loc, dstTy, kindMap);
       }
       mlir::Value widthValue = mlir::arith::ConstantOp::create(
           rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));



More information about the flang-commits mailing list