[flang-commits] [flang] 1b8a4aa - [flang][cuda] Extract element count computation into helper function (#168937)
via flang-commits
flang-commits at lists.llvm.org
Thu Nov 20 13:01:26 PST 2025
Author: Zhen Wang
Date: 2025-11-20T13:01:22-08:00
New Revision: 1b8a4aa6a5cd92f06ef9c1d6705b3426107bc655
URL: https://github.com/llvm/llvm-project/commit/1b8a4aa6a5cd92f06ef9c1d6705b3426107bc655
DIFF: https://github.com/llvm/llvm-project/commit/1b8a4aa6a5cd92f06ef9c1d6705b3426107bc655.diff
LOG: [flang][cuda] Extract element count computation into helper function (#168937)
This patch extracts the common logic for computing array element counts
from shape operands into a reusable helper function in CUFCommon.
Added:
Modified:
flang/include/flang/Optimizer/Builder/CUFCommon.h
flang/lib/Optimizer/Builder/CUFCommon.cpp
flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 6e2442745f9a0..98d01958846f7 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -39,6 +39,10 @@ int computeElementByteSize(mlir::Location loc, mlir::Type type,
fir::KindMapping &kindMap,
bool emitErrorOnFailure = true);
+/// Build a value of type \p targetType holding the number of array elements.
+/// When \p shapeOperand is set, the count is the product of its extents
+/// (taken from a fir.shape or the extent operands of a fir.shape_shift).
+/// When \p shapeOperand is null, the count is the constant array size of
+/// \p seqType, which is expected to be a fir::SequenceType.
+/// Returns a null mlir::Value when no count can be computed.
+mlir::Value computeElementCount(mlir::PatternRewriter &rewriter,
+ mlir::Location loc, mlir::Value shapeOperand,
+ mlir::Type seqType, mlir::Type targetType);
+
} // namespace cuf
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index 461deb8e4cb55..2266f4d47a0cf 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -114,3 +114,44 @@ int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
mlir::emitError(loc, "unsupported type");
return 0;
}
+
+mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter,
+                                     mlir::Location loc,
+                                     mlir::Value shapeOperand,
+                                     mlir::Type seqType,
+                                     mlir::Type targetType) {
+  // Returns the element count as a `targetType` value, or a null Value.
+  if (!shapeOperand) {
+    // Static extent - only usable when the sequence shape is fully constant.
+    auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType);
+    if (!seqTy || !seqTy.hasConstantShape())
+      return mlir::Value();
+    mlir::IntegerAttr attr =
+        rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize());
+    return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr);
+  }
+
+  // Dynamic extent - extract from the shape operand. getDefiningOp<OpTy>()
+  // is null-safe, so a shape that is a block argument does not assert.
+  llvm::SmallVector<mlir::Value> extents;
+  if (auto shapeOp = shapeOperand.getDefiningOp<fir::ShapeOp>()) {
+    extents = shapeOp.getExtents();
+  } else if (auto shapeShiftOp =
+                 shapeOperand.getDefiningOp<fir::ShapeShiftOp>()) {
+    // getPairs() interleaves (lower bound, extent); keep the odd entries.
+    for (auto pair : llvm::enumerate(shapeShiftOp.getPairs()))
+      if (pair.index() & 1)
+        extents.push_back(pair.value());
+  }
+  if (extents.empty())
+    return mlir::Value();
+
+  // Total element count is the product of all extents.
+  mlir::Value count =
+      fir::ConvertOp::create(rewriter, loc, targetType, extents.front());
+  for (mlir::Value extent : llvm::drop_begin(extents))
+    count = mlir::arith::MulIOp::create(
+        rewriter, loc, count,
+        fir::ConvertOp::create(rewriter, loc, targetType, extent));
+  return count;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 5b1b0a2f6feab..caf9b7b8b38f2 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -651,31 +651,8 @@ struct CUFDataTransferOpConversion
}
mlir::Type i64Ty = builder.getI64Type();
- mlir::Value nbElement;
- if (op.getShape()) {
- llvm::SmallVector<mlir::Value> extents;
- if (auto shapeOp =
- mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
- extents = shapeOp.getExtents();
- } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
- op.getShape().getDefiningOp())) {
- for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
- if (i.index() & 1)
- extents.push_back(i.value());
- }
-
- nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]);
- for (unsigned i = 1; i < extents.size(); ++i) {
- auto operand =
- fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]);
- nbElement =
- mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand);
- }
- } else {
- if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
- nbElement = builder.createIntegerConstant(
- loc, i64Ty, seqTy.getConstantArraySize());
- }
+ mlir::Value nbElement =
+ cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
unsigned width = 0;
if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
mlir::Type structTy =
More information about the flang-commits
mailing list