[flang-commits] [flang] 1b8a4aa - [flang][cuda] Extract element count computation into helper function (#168937)

via flang-commits flang-commits at lists.llvm.org
Thu Nov 20 13:01:26 PST 2025


Author: Zhen Wang
Date: 2025-11-20T13:01:22-08:00
New Revision: 1b8a4aa6a5cd92f06ef9c1d6705b3426107bc655

URL: https://github.com/llvm/llvm-project/commit/1b8a4aa6a5cd92f06ef9c1d6705b3426107bc655
DIFF: https://github.com/llvm/llvm-project/commit/1b8a4aa6a5cd92f06ef9c1d6705b3426107bc655.diff

LOG: [flang][cuda] Extract element count computation into helper function (#168937)

This patch extracts the common logic for computing array element counts
from shape operands into a reusable helper function in CUFCommon.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Builder/CUFCommon.h
    flang/lib/Optimizer/Builder/CUFCommon.cpp
    flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 6e2442745f9a0..98d01958846f7 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -39,6 +39,10 @@ int computeElementByteSize(mlir::Location loc, mlir::Type type,
                            fir::KindMapping &kindMap,
                            bool emitErrorOnFailure = true);
 
+/// Build IR computing the number of elements of an array, as a value of
+/// integer type \p targetType. When \p shapeOperand is non-null, the extents
+/// come from it (it is expected to be defined by fir::ShapeOp or
+/// fir::ShapeShiftOp) and are converted and multiplied together; otherwise
+/// the constant array size of \p seqType is materialized as a constant.
+/// Returns a null mlir::Value if \p shapeOperand is defined by some other op,
+/// or if it is null and \p seqType is not a fir::SequenceType.
+mlir::Value computeElementCount(mlir::PatternRewriter &rewriter,
+                                mlir::Location loc, mlir::Value shapeOperand,
+                                mlir::Type seqType, mlir::Type targetType);
+
 } // namespace cuf
 
 #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

diff  --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index 461deb8e4cb55..2266f4d47a0cf 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -114,3 +114,44 @@ int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
     mlir::emitError(loc, "unsupported type");
   return 0;
 }
+
+/// Return the element count of \p shapeOperand (or, if it is null, of the
+/// constant shape of \p seqType) as a \p targetType value; null on failure.
+mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter,
+                                     mlir::Location loc,
+                                     mlir::Value shapeOperand,
+                                     mlir::Type seqType,
+                                     mlir::Type targetType) {
+  if (!shapeOperand) {
+    // Static extent: fold the size of a constant-shape sequence type.
+    auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType);
+    if (!seqTy || seqTy.hasDynamicExtents())
+      return mlir::Value();
+    mlir::IntegerAttr attr =
+        rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize());
+    return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr);
+  }
+  // Dynamic extent. A block argument has no defining op, hence `_or_null`.
+  mlir::Operation *shapeDef = shapeOperand.getDefiningOp();
+  llvm::SmallVector<mlir::Value> extents;
+  if (auto shapeOp = mlir::dyn_cast_or_null<fir::ShapeOp>(shapeDef)) {
+    extents = shapeOp.getExtents();
+  } else if (auto shapeShiftOp =
+                 mlir::dyn_cast_or_null<fir::ShapeShiftOp>(shapeDef)) {
+    // Operands are (lower bound, extent) pairs; keep the odd positions.
+    for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
+      if (i.index() & 1)
+        extents.push_back(i.value());
+  }
+  if (extents.empty())
+    return mlir::Value();
+
+  // Total element count is the product of all extents.
+  mlir::Value count =
+      fir::ConvertOp::create(rewriter, loc, targetType, extents[0]);
+  for (mlir::Value extent : llvm::drop_begin(extents)) {
+    auto operand = fir::ConvertOp::create(rewriter, loc, targetType, extent);
+    count = mlir::arith::MulIOp::create(rewriter, loc, count, operand);
+  }
+  return count;
+}

diff  --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 5b1b0a2f6feab..caf9b7b8b38f2 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -651,31 +651,8 @@ struct CUFDataTransferOpConversion
       }
 
       mlir::Type i64Ty = builder.getI64Type();
-      mlir::Value nbElement;
-      if (op.getShape()) {
-        llvm::SmallVector<mlir::Value> extents;
-        if (auto shapeOp =
-                mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
-          extents = shapeOp.getExtents();
-        } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
-                       op.getShape().getDefiningOp())) {
-          for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
-            if (i.index() & 1)
-              extents.push_back(i.value());
-        }
-
-        nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]);
-        for (unsigned i = 1; i < extents.size(); ++i) {
-          auto operand =
-              fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]);
-          nbElement =
-              mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand);
-        }
-      } else {
-        if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
-          nbElement = builder.createIntegerConstant(
-              loc, i64Ty, seqTy.getConstantArraySize());
-      }
+      mlir::Value nbElement =
+          cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
       unsigned width = 0;
       if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
         mlir::Type structTy =


        


More information about the flang-commits mailing list