[flang-commits] [flang] [flang][cuda] Extract element count computation into helper function (PR #168937)
Zhen Wang via flang-commits
flang-commits at lists.llvm.org
Thu Nov 20 11:37:54 PST 2025
https://github.com/wangzpgi created https://github.com/llvm/llvm-project/pull/168937
This patch extracts the common logic for computing array element counts from shape operands into a reusable helper function in CUFCommon.
From 3dd6b91f6c14f76ec14b8b1a75e87c14ca0a9291 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Thu, 20 Nov 2025 11:34:49 -0800
Subject: [PATCH] Extract element count computation into helper function
---
.../flang/Optimizer/Builder/CUFCommon.h | 4 ++
flang/lib/Optimizer/Builder/CUFCommon.cpp | 41 +++++++++++++++++++
.../Optimizer/Transforms/CUFOpConversion.cpp | 27 +-----------
3 files changed, 47 insertions(+), 25 deletions(-)
diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 6e2442745f9a0..98d01958846f7 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -39,6 +39,10 @@ int computeElementByteSize(mlir::Location loc, mlir::Type type,
fir::KindMapping &kindMap,
bool emitErrorOnFailure = true);
+mlir::Value computeElementCount(mlir::PatternRewriter &rewriter,
+ mlir::Location loc, mlir::Value shapeOperand,
+ mlir::Type seqType, mlir::Type targetType);
+
} // namespace cuf
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index 461deb8e4cb55..2266f4d47a0cf 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -114,3 +114,44 @@ int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
mlir::emitError(loc, "unsupported type");
return 0;
}
+
+/// Emit IR for the array element count as a value of type `targetType`: the
+/// product of the extents of the fir.shape/fir.shape_shift that defines
+/// `shapeOperand`, or the constant size of `seqType` when `shapeOperand` is
+/// null. Returns a null value when no count can be derived.
+mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter,
+                                     mlir::Location loc,
+                                     mlir::Value shapeOperand,
+                                     mlir::Type seqType,
+                                     mlir::Type targetType) {
+  if (!shapeOperand) {
+    // Static extent - use constant array size; bail out on a dynamic shape.
+    auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType);
+    if (!seqTy || seqTy.hasDynamicExtents())
+      return mlir::Value();
+    mlir::IntegerAttr attr =
+        rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize());
+    return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr);
+  }
+  // Dynamic extent - extract from shape operand. getDefiningOp<OpTy>() is
+  // null-safe, so a block-argument shape yields no extents (null result).
+  llvm::SmallVector<mlir::Value> extents;
+  if (auto shapeOp = shapeOperand.getDefiningOp<fir::ShapeOp>()) {
+    extents = shapeOp.getExtents();
+  } else if (auto shiftOp = shapeOperand.getDefiningOp<fir::ShapeShiftOp>()) {
+    // Odd-indexed entries of the pairs list are taken as the extents.
+    for (auto i : llvm::enumerate(shiftOp.getPairs()))
+      if (i.index() & 1)
+        extents.push_back(i.value());
+  }
+  if (extents.empty())
+    return mlir::Value();
+  // Compute total element count by multiplying all dimensions
+  mlir::Value count =
+      fir::ConvertOp::create(rewriter, loc, targetType, extents[0]);
+  for (mlir::Value extent : llvm::drop_begin(extents)) {
+    auto operand = fir::ConvertOp::create(rewriter, loc, targetType, extent);
+    count = mlir::arith::MulIOp::create(rewriter, loc, count, operand);
+  }
+  return count;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 5b1b0a2f6feab..caf9b7b8b38f2 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -651,31 +651,8 @@ struct CUFDataTransferOpConversion
}
mlir::Type i64Ty = builder.getI64Type();
- mlir::Value nbElement;
- if (op.getShape()) {
- llvm::SmallVector<mlir::Value> extents;
- if (auto shapeOp =
- mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
- extents = shapeOp.getExtents();
- } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
- op.getShape().getDefiningOp())) {
- for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
- if (i.index() & 1)
- extents.push_back(i.value());
- }
-
- nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]);
- for (unsigned i = 1; i < extents.size(); ++i) {
- auto operand =
- fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]);
- nbElement =
- mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand);
- }
- } else {
- if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
- nbElement = builder.createIntegerConstant(
- loc, i64Ty, seqTy.getConstantArraySize());
- }
+ mlir::Value nbElement =
+ cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
unsigned width = 0;
if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
mlir::Type structTy =
More information about the flang-commits
mailing list