[flang-commits] [flang] [flang] Generalized simplification of HLFIR reduction ops. (PR #136071)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Thu Apr 17 09:03:07 PDT 2025
================
@@ -173,245 +173,928 @@ class TransposeAsElementalConversion
}
};
-// Expand the SUM(DIM=CONSTANT) operation into .
-class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+/// CRTP class for converting reduction-like operations into
+/// a reduction loop[-nest] optionally wrapped into hlfir.elemental.
+/// It is used to handle operations produced for ALL, ANY, COUNT,
+/// MAXLOC, MAXVAL, MINLOC, MINVAL, SUM intrinsics.
+///
+/// All of these operations take an input array, and optional
+/// dim, mask arguments. ALL, ANY, COUNT do not have mask argument.
+template <typename T>
+class ReductionAsElementalConverter {
public:
- using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+ ReductionAsElementalConverter(mlir::Operation *op,
+ mlir::PatternRewriter &rewriter)
+ : op{op}, rewriter{rewriter}, loc{op->getLoc()}, builder{rewriter, op} {
+ assert(op->getNumResults() == 1);
+ }
- llvm::LogicalResult
- matchAndRewrite(hlfir::SumOp sum,
- mlir::PatternRewriter &rewriter) const override {
- hlfir::Entity array = hlfir::Entity{sum.getArray()};
- bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
- mlir::Value dim = sum.getDim();
+ /// Do the actual conversion or return mlir::failure(),
+ /// if conversion is not possible.
+ mlir::LogicalResult convert();
+
+private:
+ /// Return an instance of the derived class that implements
+ /// the interface.
+ T &impl() { return *static_cast<T *>(this); }
+ const T &impl() const { return *static_cast<const T *>(this); }
+
+ // Return fir.shape specifying the shape of the result
+ // of a reduction with DIM=dimVal. The second return value
+ // is the extent of the DIM dimension.
+ std::tuple<mlir::Value, mlir::Value>
+ genResultShapeForPartialReduction(hlfir::Entity array, int64_t dimVal);
+
+ /// \p mask is a scalar or array logical mask.
+ /// If \p isPresentPred is not nullptr, it is a dynamic predicate value
+ /// identifying whether the mask's variable is present.
+ /// \p indices is a range of one-based indices to access \p mask
+ /// when it is an array.
+ ///
+ /// The method returns the scalar mask value to guard the access
+ /// to a single element of the input array.
+ mlir::Value genMaskValue(mlir::Value mask, mlir::Value isPresentPred,
+ mlir::ValueRange indices);
+
+protected:
+ // Methods below must be implemented by the derived type.
+
+ /// Return the input array.
+ mlir::Value getSource() const {
+ llvm_unreachable("derived type must provide getSource()");
+ }
----------------
vzakhari wrote:
Thanks for the reviews!
I guess CRTP might be overkill here, and a virtual interface would be enough. I used CRTP because MLIR code uses it extensively, though probably mostly for performance reasons, which are not a big concern here. Another reason is that I wanted to define overridable functions with different result types (e.g. see the different `SmallVector` sizes used for the results of some methods), which is not straightforward with virtual overrides. Now MSVC does not like it, though, so I will have to do something about it :) I may end up rewriting it with a virtual interface.
https://github.com/llvm/llvm-project/pull/136071