[flang-commits] [flang] [flang] Generalized simplification of HLFIR reduction ops. (PR #136071)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Apr 17 04:16:03 PDT 2025


================
@@ -173,245 +173,928 @@ class TransposeAsElementalConversion
   }
 };
 
-// Expand the SUM(DIM=CONSTANT) operation into .
-class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+/// CRTP class for converting reduction-like operations into
+/// a reduction loop[-nest] optionally wrapped into hlfir.elemental.
+/// It is used to handle operations produced for ALL, ANY, COUNT,
+/// MAXLOC, MAXVAL, MINLOC, MINVAL, SUM intrinsics.
+///
+/// All of these operations take an input array and optional
+/// DIM and MASK arguments; ALL, ANY and COUNT have no separate MASK.
+template <typename T>
+class ReductionAsElementalConverter {
 public:
-  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+  ReductionAsElementalConverter(mlir::Operation *op,
+                                mlir::PatternRewriter &rewriter)
+      : op{op}, rewriter{rewriter}, loc{op->getLoc()}, builder{rewriter, op} {
+    assert(op->getNumResults() == 1);
+  }
 
-  llvm::LogicalResult
-  matchAndRewrite(hlfir::SumOp sum,
-                  mlir::PatternRewriter &rewriter) const override {
-    hlfir::Entity array = hlfir::Entity{sum.getArray()};
-    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
-    mlir::Value dim = sum.getDim();
+  /// Do the actual conversion or return mlir::failure(),
+  /// if conversion is not possible.
+  mlir::LogicalResult convert();
+
+private:
+  /// Return an instance of the derived class that implements
+  /// the interface.
+  T &impl() { return *static_cast<T *>(this); }
+  const T &impl() const { return *static_cast<const T *>(this); }
+
+  // Return fir.shape specifying the shape of the result
+  // of a reduction with DIM=dimVal. The second return value
+  // is the extent of the DIM dimension.
+  std::tuple<mlir::Value, mlir::Value>
+  genResultShapeForPartialReduction(hlfir::Entity array, int64_t dimVal);
+
+  /// \p mask is a scalar or array logical mask.
+  /// If \p isPresentPred is not nullptr, it is a dynamic predicate value
+  /// identifying whether the mask's variable is present.
+  /// \p indices is a range of one-based indices to access \p mask
+  /// when it is an array.
+  ///
+  /// The method returns the scalar mask value to guard the access
+  /// to a single element of the input array.
+  mlir::Value genMaskValue(mlir::Value mask, mlir::Value isPresentPred,
+                           mlir::ValueRange indices);
+
+protected:
+  // Methods below must be implemented by the derived type.
+
+  /// Return the input array.
+  mlir::Value getSource() const {
+    llvm_unreachable("derived type must provide getSource()");
+  }
+
+  /// Return DIM or nullptr, if it is not present.
+  mlir::Value getDim() const {
+    llvm_unreachable("derived type must provide getDim()");
+  }
+
+  /// Return MASK or nullptr, if it is not present.
+  mlir::Value getMask() const {
+    llvm_unreachable("derived type must provide getMask()");
+  }
+
+  /// Return FastMathFlags attached to the operation
+  /// or arith::FastMathFlags::none, if the operation
+  /// does not support FastMathFlags (e.g. ALL, ANY, COUNT).
+  mlir::arith::FastMathFlags getFastMath() const {
+    llvm_unreachable("derived type must provide getFastMath()");
+  }
+
+  /// Generates initial values for the reduction values used
+  /// by the reduction loop. In general, there is a single
+  /// loop-carried reduction value (e.g. for SUM), but, for example,
+  /// MAXLOC/MINLOC implementation uses multiple reductions.
+  llvm::SmallVector<mlir::Value> genReductionInitValues() {
+    llvm_unreachable("derived type must provide genReductionInitValues()");
+  }
+
+  /// Perform reduction(s) update given a single input array's element
+  /// identified by \p array and \p oneBasedIndices coordinates.
+  /// \p currentValue specifies the current value(s) of the reduction(s)
+  /// inside the reduction loop body.
+  llvm::SmallVector<mlir::Value>
+  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
+                   hlfir::Entity array, mlir::ValueRange oneBasedIndices) {
+    llvm_unreachable("derived type must provide reduceOneElement()");
+  }
+
+  /// Given reduction value(s) in \p reductionResults produced
+  /// by the reduction loop, apply any required updates and return
+  /// new reduction value(s) to be used after the reduction loop
+  /// (e.g. as the result yield of the wrapping hlfir.elemental).
+  /// NOTE: if the reduction loop is wrapped in hlfir.elemental,
+  /// the insertion point of any generated code is inside hlfir.elemental.
+  hlfir::Entity
+  genFinalResult(const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
+    llvm_unreachable("derived type must provide genFinalResult()");
+  }
+
+  // Methods below may be shadowed by the derived type.
+
+  /// Return mlir::success(), if the operation can be converted.
+  /// The default implementation always returns mlir::success().
+  /// The derived type may shadow the default implementation
+  /// with its own definition.
+  mlir::LogicalResult isConvertible() const { return mlir::success(); }
+
+  // Default implementation of isTotalReduction() just checks
+  // if the result of the operation is a scalar.
+  // True result indicates that the reduction has to be done
+  // across all elements, false result indicates that
+  // the result is an array expression produced by an hlfir.elemental
+  // operation with a single reduction loop across the DIM dimension.
+  //
+  // MAXLOC/MINLOC must override this.
+  bool isTotalReduction() const { return getResultRank() == 0; }
+
+  // Return true, if the reduction loop[-nest] may be unordered.
+  // Integer, logical and character reductions are always unordered;
+  // other (e.g. FP) reductions may only be unordered when
+  // FastMathFlags::reassoc (reported by getFastMath()) is set.
+  // Some derived types may need to shadow this.
+  bool isUnordered() const {
+    mlir::Type elemType = getSourceElementType();
+    if (mlir::isa<mlir::IntegerType, fir::LogicalType, fir::CharacterType>(
+            elemType))
+      return true;
+    return static_cast<bool>(impl().getFastMath() &
+                             mlir::arith::FastMathFlags::reassoc);
+  }
+
+  // Methods below are utilities that are not supposed to be
+  // overridden by the derived type.
+
+  /// Return 0 for a total reduction, i.e. if DIM is not present or its
+  /// value does not matter (for example, a reduction of 1D array does not
+  /// care about the DIM value, assuming that it is a valid program).
+  /// Return mlir::failure(), if DIM is not a constant, or is a constant
+  /// known to be invalid (outside [1, source rank]) for the given array.
+  /// Otherwise, return DIM constant value.
+  mlir::FailureOr<int64_t> getConstDim() const {
     int64_t dimVal = 0;
-    if (!isTotalReduction) {
+    if (!impl().isTotalReduction()) {
       // In case of partial reduction we should ignore the operations
       // with invalid DIM values. They may appear in dead code
       // after constant propagation.
-      auto constDim = fir::getIntIfConstant(dim);
+      auto constDim = fir::getIntIfConstant(impl().getDim());
       if (!constDim)
-        return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
+        return rewriter.notifyMatchFailure(op, "Nonconstant DIM");
       dimVal = *constDim;
 
-      if ((dimVal <= 0 || dimVal > array.getRank()))
-        return rewriter.notifyMatchFailure(
-            sum, "Invalid DIM for partial SUM reduction");
+      if ((dimVal <= 0 || dimVal > getSourceRank()))
+        return rewriter.notifyMatchFailure(op,
+                                           "Invalid DIM for partial reduction");
     }
+    return dimVal;
+  }
 
-    mlir::Location loc = sum.getLoc();
-    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
-    mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
-    mlir::Value mask = sum.getMask();
+  /// Return hlfir::Entity of the result.
+  hlfir::Entity getResultEntity() const {
+    return hlfir::Entity{op->getResult(0)};
+  }
 
-    mlir::Value resultShape, dimExtent;
-    llvm::SmallVector<mlir::Value> arrayExtents;
-    if (isTotalReduction)
-      arrayExtents = hlfir::genExtentsVector(loc, builder, array);
-    else
-      std::tie(resultShape, dimExtent) =
-          genResultShapeForPartialReduction(loc, builder, array, dimVal);
-
-    // If the mask is present and is a scalar, then we'd better load its value
-    // outside of the reduction loop making the loop unswitching easier.
-    mlir::Value isPresentPred, maskValue;
-    if (mask) {
-      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
-        // MASK represented by a box might be dynamically optional,
-        // so we have to check for its presence before accessing it.
-        isPresentPred =
-            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
-      }
+  /// Return type of the result (e.g. !hlfir.expr<?xi32>).
+  mlir::Type getResultType() const { return getResultEntity().getType(); }
 
-      if (hlfir::Entity{mask}.isScalar())
-        maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
-    }
+  /// Return the element type of the result (e.g. i32).
+  mlir::Type getResultElementType() const {
+    return hlfir::getFortranElementType(getResultType());
+  }
 
-    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
-                         mlir::ValueRange inputIndices) -> hlfir::Entity {
-      // Loop over all indices in the DIM dimension, and reduce all values.
-      // If DIM is not present, do total reduction.
-
-      // Initial value for the reduction.
-      mlir::Value reductionInitValue =
-          fir::factory::createZeroValue(builder, loc, elementType);
-
-      // The reduction loop may be unordered if FastMathFlags::reassoc
-      // transformations are allowed. The integer reduction is always
-      // unordered.
-      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
-                         static_cast<bool>(sum.getFastmath() &
-                                           mlir::arith::FastMathFlags::reassoc);
+  /// Return rank of the result.
+  unsigned getResultRank() const { return getResultEntity().getRank(); }
 
-      llvm::SmallVector<mlir::Value> extents;
-      if (isTotalReduction)
-        extents = arrayExtents;
-      else
-        extents.push_back(
-            builder.createConvert(loc, builder.getIndexType(), dimExtent));
+  /// Return the element type of the result.
----------------
tblah wrote:

```suggestion
  /// Return the element type of the source.
```

https://github.com/llvm/llvm-project/pull/136071


More information about the flang-commits mailing list