[flang-commits] [flang] [flang] Generalized simplification of HLFIR reduction ops. (PR #136071)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Apr 17 08:55:44 PDT 2025


================
@@ -173,245 +173,928 @@ class TransposeAsElementalConversion
   }
 };
 
-// Expand the SUM(DIM=CONSTANT) operation into .
-class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+/// CRTP class for converting reduction-like operations into
+/// a reduction loop[-nest] optionally wrapped into hlfir.elemental.
+/// It is used to handle operations produced for ALL, ANY, COUNT,
+/// MAXLOC, MAXVAL, MINLOC, MINVAL, SUM intrinsics.
+///
+/// All of these operations take an input array, and optional
+/// dim, mask arguments. ALL, ANY, COUNT do not have a mask argument.
+template <typename T>
+class ReductionAsElementalConverter {
 public:
-  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+  ReductionAsElementalConverter(mlir::Operation *op,
+                                mlir::PatternRewriter &rewriter)
+      : op{op}, rewriter{rewriter}, loc{op->getLoc()}, builder{rewriter, op} {
+    assert(op->getNumResults() == 1);
+  }
 
-  llvm::LogicalResult
-  matchAndRewrite(hlfir::SumOp sum,
-                  mlir::PatternRewriter &rewriter) const override {
-    hlfir::Entity array = hlfir::Entity{sum.getArray()};
-    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
-    mlir::Value dim = sum.getDim();
+  /// Do the actual conversion or return mlir::failure(),
+  /// if conversion is not possible.
+  mlir::LogicalResult convert();
+
+private:
+  /// Return an instance of the derived class that implements
+  /// the interface.
+  T &impl() { return *static_cast<T *>(this); }
+  const T &impl() const { return *static_cast<const T *>(this); }
+
+  // Return fir.shape specifying the shape of the result
+  // of a reduction with DIM=dimVal. The second return value
+  // is the extent of the DIM dimension.
+  std::tuple<mlir::Value, mlir::Value>
+  genResultShapeForPartialReduction(hlfir::Entity array, int64_t dimVal);
+
+  /// \p mask is a scalar or array logical mask.
+  /// If \p isPresentPred is not nullptr, it is a dynamic predicate value
+  /// identifying whether the mask's variable is present.
+  /// \p indices is a range of one-based indices to access \p mask
+  /// when it is an array.
+  ///
+  /// The method returns the scalar mask value to guard the access
+  /// to a single element of the input array.
+  mlir::Value genMaskValue(mlir::Value mask, mlir::Value isPresentPred,
+                           mlir::ValueRange indices);
+
+protected:
+  // Methods below must be implemented by the derived type.
+
+  /// Return the input array.
+  mlir::Value getSource() const {
+    llvm_unreachable("derived type must provide getSource()");
+  }
----------------
tblah wrote:

+1 this is a very gentle suggestion, please feel free to ignore.

https://github.com/llvm/llvm-project/pull/136071


More information about the flang-commits mailing list