[flang-commits] [flang] [flang] Generalized simplification of HLFIR reduction ops. (PR #136071)
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Thu Apr 17 04:16:03 PDT 2025
================
@@ -173,245 +173,928 @@ class TransposeAsElementalConversion
}
};
-// Expand the SUM(DIM=CONSTANT) operation into .
-class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+/// CRTP class for converting reduction-like operations into
+/// a reduction loop[-nest] optionally wrapped into hlfir.elemental.
+/// It is used to handle operations produced for ALL, ANY, COUNT,
+/// MAXLOC, MAXVAL, MINLOC, MINVAL, SUM intrinsics.
+///
+/// All of these operations take an input array, and optional
+/// dim, mask arguments. ALL, ANY, COUNT do not have mask argument.
+template <typename T>
+class ReductionAsElementalConverter {
public:
- using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+ ReductionAsElementalConverter(mlir::Operation *op,
+ mlir::PatternRewriter &rewriter)
+ : op{op}, rewriter{rewriter}, loc{op->getLoc()}, builder{rewriter, op} {
+ assert(op->getNumResults() == 1);
+ }
- llvm::LogicalResult
- matchAndRewrite(hlfir::SumOp sum,
- mlir::PatternRewriter &rewriter) const override {
- hlfir::Entity array = hlfir::Entity{sum.getArray()};
- bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
- mlir::Value dim = sum.getDim();
+ /// Do the actual conversion or return mlir::failure(),
+ /// if conversion is not possible.
+ mlir::LogicalResult convert();
+
+private:
+ /// Return an instance of the derived class that implements
+ /// the interface.
+ T &impl() { return *static_cast<T *>(this); }
+ const T &impl() const { return *static_cast<const T *>(this); }
+
+ // Return fir.shape specifying the shape of the result
+ // of a reduction with DIM=dimVal. The second return value
+ // is the extent of the DIM dimension.
+ std::tuple<mlir::Value, mlir::Value>
+ genResultShapeForPartialReduction(hlfir::Entity array, int64_t dimVal);
+
+ /// \p mask is a scalar or array logical mask.
+ /// If \p isPresentPred is not nullptr, it is a dynamic predicate value
+ /// identifying whether the mask's variable is present.
+ /// \p indices is a range of one-based indices to access \p mask
+ /// when it is an array.
+ ///
+ /// The method returns the scalar mask value to guard the access
+ /// to a single element of the input array.
+ mlir::Value genMaskValue(mlir::Value mask, mlir::Value isPresentPred,
+ mlir::ValueRange indices);
+
+protected:
+ // Methods below must be implemented by the derived type.
+
+ /// Return the input array.
+ mlir::Value getSource() const {
+ llvm_unreachable("derived type must provide getSource()");
+ }
----------------
tblah wrote:
nit (if you agree): these could be made pure virtual, so that failing to implement them in the derived class would be a compile-time error. I don't think the performance penalty of the indirect method call would matter much here.
Another possibility would be to implement these by overloading free functions, e.g.:
```
template <typename T>
static mlir::Value getMask(T op) {
return op.getMask();
}
static mlir::Value getMask(hlfir::AllOp op) { return nullptr; }
```
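
To illustrate how the overloading suggestion resolves, here is a minimal standalone sketch. It assumes stand-in types: `Value`, `SumOp`, `AllOp`, and `hasMask` are hypothetical placeholders written for this example and are not the real `mlir::Value` or `hlfir` operation classes. The point it shows is that a non-template overload is preferred over the template when both match equally well, so an op without a `getMask()` member never instantiates the generic body.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for mlir::Value: an empty name plays the role
// of a null value. This is NOT the MLIR API.
struct Value {
  std::string name;
  explicit operator bool() const { return !name.empty(); }
};

// Hypothetical stand-in for an op that has a mask operand.
struct SumOp {
  Value mask;
  Value getMask() const { return mask; }
};

// Hypothetical stand-in for an op with no mask operand: note that it
// has no getMask() member at all.
struct AllOp {};

// Generic accessor: used for any op type that provides getMask().
template <typename T>
static Value getMask(const T &op) {
  return op.getMask();
}

// Non-template overload: overload resolution prefers it over the
// template for AllOp, so the generic body is never instantiated
// for a type that lacks getMask().
static Value getMask(const AllOp &) { return Value{}; }

// Generic code can call getMask() uniformly on either kind of op.
template <typename Op>
static bool hasMask(const Op &op) {
  return static_cast<bool>(getMask(op));
}

// Small self-check of the dispatch described above.
static void checkOverloads() {
  assert(hasMask(SumOp{Value{"mask"}}));
  assert(!hasMask(AllOp{}));
}
```

With this shape, an op type that neither provides `getMask()` nor has a dedicated overload fails to compile at the call site, which is the compile-time diagnostic the comment is after, and no virtual dispatch is involved.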
https://github.com/llvm/llvm-project/pull/136071