[flang-commits] [flang] [flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. (PR #118556)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 03:12:32 PST 2024


================
@@ -90,13 +91,198 @@ class TransposeAsElementalConversion
   }
 };
 
+// Expand the SUM(DIM=CONSTANT) operation into hlfir.elemental.
+class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::SumOp sum,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = sum.getLoc();
+    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
+    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
+    assert(expr && "expected an expression type for the result of hlfir.sum");
+    mlir::Type elementType = expr.getElementType();
+    hlfir::Entity array = hlfir::Entity{sum.getArray()};
+    mlir::Value mask = sum.getMask();
+    mlir::Value dim = sum.getDim();
+    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+    assert(dimVal > 0 && "DIM must be present and a positive constant");
+    mlir::Value resultShape, dimExtent;
+    std::tie(resultShape, dimExtent) =
+        genResultShape(loc, builder, array, dimVal);
+
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange inputIndices) -> hlfir::Entity {
+      // Loop over all indices in the DIM dimension, and reduce all values.
+      // We do not need to create the reduction loop always: if we can
+      // slice the input array given the inputIndices, then we can
+      // just apply a new SUM operation (total reduction) to the slice.
+      // For the time being, generate the explicit loop because the slicing
+      // requires generating an elemental operation for the input array
+      // (and the mask, if present).
+      // TODO: produce the slices and new SUM after adding a pattern
+      // for expanding total reduction SUM case.
+      mlir::Type indexType = builder.getIndexType();
+      auto one = builder.createIntegerConstant(loc, indexType, 1);
+      auto ub = builder.createConvert(loc, indexType, dimExtent);
+
+      // Initial value for the reduction.
+      mlir::Value initValue = genInitValue(loc, builder, elementType);
+
+      // The reduction loop may be unordered if FastMathFlags::reassoc
+      // transformations are allowed. The integer reduction is always
+      // unordered.
+      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
+                         static_cast<bool>(sum.getFastmath() &
+                                           mlir::arith::FastMathFlags::reassoc);
+
+      // If the mask is present and is a scalar, load its value outside
+      // of the reduction loop to make loop unswitching easier.
+      // It may be worth hoisting it out of the elemental operation as well.
+      if (mask) {
+        hlfir::Entity maskValue{mask};
+        if (maskValue.isScalar())
+          mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
+      }
+
+      // NOTE: the outer elemental operation may be lowered into
+      // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
+      // loop may appear disjoint from the workshare loop nest.
+      // Moreover, the inner loop is not strictly nested (due to the reduction
+      // starting value initialization), and the above omp dialect operations
+      // cannot produce results.
+      // It is not yet clear what we should do about it.
----------------
tblah wrote:

I think this is okay. Most intrinsics are going to be evaluated in a single thread in WORKSHARE for now (which is what some other compilers do too). In this case, I think SUM would be best implemented with a special rewrite pattern for OpenMP using a reduction clause.
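The reduction clause mentioned above is standard OpenMP. The following minimal C++ sketch shows a total reduction (SUM with no DIM) written with it. This is plain OpenMP in C++, not a Flang rewrite pattern, and the function name `parallelSum` is hypothetical. Integer addition is associative, so the per-thread partial sums can be combined in any order. The hunk relies on the same property when it marks integer reductions as unordered.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Total reduction (SUM with no DIM) expressed with an OpenMP reduction clause.
// Each thread accumulates into a private copy of 'total' initialised to 0 (the
// identity for '+'); the private copies are combined when the loop ends.
// Compiled without -fopenmp, the pragma is ignored and the loop runs serially.
std::int64_t parallelSum(const std::vector<std::int64_t> &values) {
  std::int64_t total = 0;
  const auto n = static_cast<std::int64_t>(values.size());
#pragma omp parallel for reduction(+ : total)
  for (std::int64_t i = 0; i < n; ++i)
    total += values[i];
  return total;
}
```

For floating-point data, a parallel reduction like this changes the summation order. That is only acceptable when reassociation is permitted, which matches the `reassoc` fast-math check in the hunk.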

In general, implementing good multithreaded versions of these intrinsics that are useful on both CPUs and offloading devices is quite hard. My opinion is that we should only attempt this when there is a concrete performance case to benchmark. I wouldn't want this relatively rare OpenMP construct (with historically poor compiler support) to make performance work in the rest of the compiler more difficult.
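To make the single-threaded expansion in the hunk concrete, here is a plain C++ reference model of what the generated `hlfir.elemental` computes for `SUM(ARRAY, DIM=dim[, MASK])` on a rank-2 array. This is a sketch under stated assumptions, not MLIR code. The names `Array2D` and `sumAlongDim` are hypothetical. Arrays are assumed to be stored column-major (Fortran order). A scalar MASK is not modelled. Comments point to the corresponding names in the hunk (`resultShape`, `dimExtent`, `genInitValue`).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical rank-2 array stored column-major (Fortran order):
// element (i, j), 0-based, lives at index i + j * rows.
struct Array2D {
  std::size_t rows, cols;
  std::vector<double> data;
};

// Reference model of SUM(ARRAY, DIM=dim[, MASK]) for a rank-2 array.
// dim is 1 or 2, as in Fortran. mask, if non-null, has the same
// column-major layout as the array. The result is rank-1, with the
// input shape minus dimension 'dim'.
std::vector<double> sumAlongDim(const Array2D &a, int dim,
                                const std::vector<bool> *mask = nullptr) {
  assert(dim == 1 || dim == 2);
  std::size_t resultExtent = dim == 1 ? a.cols : a.rows; // cf. resultShape
  std::size_t dimExtent = dim == 1 ? a.rows : a.cols;    // cf. dimExtent
  std::vector<double> result(resultExtent);
  // Outer loop: one independent iteration per result element, i.e. the
  // index space of the elemental operation.
  for (std::size_t r = 0; r < resultExtent; ++r) {
    double acc = 0.0; // initial value of the reduction (cf. genInitValue)
    // Inner reduction loop over the DIM dimension, in order.
    for (std::size_t k = 0; k < dimExtent; ++k) {
      std::size_t i = dim == 1 ? k : r;
      std::size_t j = dim == 1 ? r : k;
      std::size_t idx = i + j * a.rows;
      if (mask && !(*mask)[idx])
        continue; // masked-out elements do not contribute
      acc += a.data[idx];
    }
    result[r] = acc;
  }
  return result;
}
```

The inner loop here runs in a fixed order. Per the hunk, the generated loop may be unordered for integers, or for floating point when `reassoc` is set. The outer iterations are independent of each other, which is why the elemental could later be distributed by a WORKSHARE lowering while each reduction stays serial.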

https://github.com/llvm/llvm-project/pull/118556

