[flang-commits] [flang] [flang] Generalized simplification of HLFIR reduction ops. (PR #136071)
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Thu Apr 17 04:16:03 PDT 2025
================
@@ -173,245 +173,928 @@ class TransposeAsElementalConversion
}
};
-// Expand the SUM(DIM=CONSTANT) operation into .
-class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+/// CRTP class for converting reduction-like operations into
+/// a reduction loop[-nest] optionally wrapped into hlfir.elemental.
+/// It is used to handle operations produced for ALL, ANY, COUNT,
+/// MAXLOC, MAXVAL, MINLOC, MINVAL, SUM intrinsics.
+///
+/// All of these operations take an input array, and optional
+/// dim, mask arguments. ALL, ANY, COUNT do not have mask argument.
+template <typename T>
+class ReductionAsElementalConverter {
public:
- using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+ ReductionAsElementalConverter(mlir::Operation *op,
+ mlir::PatternRewriter &rewriter)
+ : op{op}, rewriter{rewriter}, loc{op->getLoc()}, builder{rewriter, op} {
+ assert(op->getNumResults() == 1);
+ }
- llvm::LogicalResult
- matchAndRewrite(hlfir::SumOp sum,
- mlir::PatternRewriter &rewriter) const override {
- hlfir::Entity array = hlfir::Entity{sum.getArray()};
- bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
- mlir::Value dim = sum.getDim();
+ /// Do the actual conversion or return mlir::failure(),
+ /// if conversion is not possible.
+ mlir::LogicalResult convert();
+
+private:
+ /// Return an instance of the derived class that implements
+ /// the interface.
+ T &impl() { return *static_cast<T *>(this); }
+ const T &impl() const { return *static_cast<const T *>(this); }
+
+ // Return fir.shape specifying the shape of the result
+ // of a reduction with DIM=dimVal. The second return value
+ // is the extent of the DIM dimension.
+ std::tuple<mlir::Value, mlir::Value>
+ genResultShapeForPartialReduction(hlfir::Entity array, int64_t dimVal);
+
+ /// \p mask is a scalar or array logical mask.
+ /// If \p isPresentPred is not nullptr, it is a dynamic predicate value
+ /// identifying whether the mask's variable is present.
+ /// \p indices is a range of one-based indices to access \p mask
+ /// when it is an array.
+ ///
+ /// The method returns the scalar mask value to guard the access
+ /// to a single element of the input array.
+ mlir::Value genMaskValue(mlir::Value mask, mlir::Value isPresentPred,
+ mlir::ValueRange indices);
+
+protected:
+ // Methods below must be implemented by the derived type.
+
+ /// Return the input array.
+ mlir::Value getSource() const {
+ llvm_unreachable("derived type must provide getSource()");
+ }
+
+ /// Return DIM or nullptr, if it is not present.
+ mlir::Value getDim() const {
+ llvm_unreachable("derived type must provide getDim()");
+ }
+
+ /// Return MASK or nullptr, if it is not present.
+ mlir::Value getMask() const {
+ llvm_unreachable("derived type must provide getMask()");
+ }
+
+ /// Return FastMathFlags attached to the operation
+ /// or arith::FastMathFlags::none, if the operation
+ /// does not support FastMathFlags (e.g. ALL, ANY, COUNT).
+ mlir::arith::FastMathFlags getFastMath() const {
+ llvm_unreachable("derived type must provide getFastMath()");
+ }
+
+ /// Generates initial values for the reduction values used
+ /// by the reduction loop. In general, there is a single
+ /// loop-carried reduction value (e.g. for SUM), but, for example,
+ /// MAXLOC/MINLOC implementation uses multiple reductions.
+ llvm::SmallVector<mlir::Value> genReductionInitValues() {
+ llvm_unreachable("derived type must provide genReductionInitValues()");
+ }
+
+ /// Perform reduction(s) update given a single input array's element
+ /// identified by \p array and \p oneBasedIndices coordinates.
+ /// \p currentValue specifies the current value(s) of the reduction(s)
+ /// inside the reduction loop body.
+ llvm::SmallVector<mlir::Value>
+ reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
+ hlfir::Entity array, mlir::ValueRange oneBasedIndices) {
+ llvm_unreachable("derived type must provide reduceOneElement()");
+ }
+
+ /// Given reduction value(s) in \p reductionResults produced
+ /// by the reduction loop, apply any required updates and return
+ /// new reduction value(s) to be used after the reduction loop
+ /// (e.g. as the result yield of the wrapping hlfir.elemental).
+ /// NOTE: if the reduction loop is wrapped in hlfir.elemental,
+ /// the insertion point of any generated code is inside hlfir.elemental.
+ hlfir::Entity
+ genFinalResult(const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
+ llvm_unreachable("derived type must provide genFinalResult()");
+ }
+
+ // Methods below may be shadowed by the derived type.
+
+ /// Return mlir::success(), if the operation can be converted.
+ /// The default implementation always returns mlir::success().
+ /// The derived type may shadow the default implementation
+ /// with its own definition.
+ mlir::LogicalResult isConvertible() const { return mlir::success(); }
+
+ // Default implementation of isTotalReduction() just checks
+ // if the result of the operation is a scalar.
+ // True result indicates that the reduction has to be done
+ // across all elements, false result indicates that
+ // the result is an array expression produced by an hlfir.elemental
+ // operation with a single reduction loop across the DIM dimension.
+ //
+ // MAXLOC/MINLOC must override this.
+ bool isTotalReduction() const { return getResultRank() == 0; }
+
+ // Return true, if the reduction loop[-nest] may be unordered.
+ // In general, FP reductions may only be unordered when
+ // FastMathFlags::reassoc transformations are allowed.
+ //
+ // Some dervied types may need to override this.
+ bool isUnordered() const {
+ mlir::Type elemType = getSourceElementType();
+ if (mlir::isa<mlir::IntegerType, fir::LogicalType, fir::CharacterType>(
+ elemType))
+ return true;
+ return static_cast<bool>(impl().getFastMath() &
+ mlir::arith::FastMathFlags::reassoc);
+ }
+
+ // Methods below are utilities that are not supposed to be
+ // overridden by the derived type.
+
+ /// Return 0, if DIM is not present or its values does not matter
+ /// (for example, a reduction of 1D array does not care about
+ /// the DIM value, assuming that it is a valid program).
+ /// Return mlir::failure(), if DIM is a constant known
+ /// to be invalid for the given array.
+ /// Otherwise, return DIM constant value.
+ mlir::FailureOr<int64_t> getConstDim() const {
int64_t dimVal = 0;
- if (!isTotalReduction) {
+ if (!impl().isTotalReduction()) {
// In case of partial reduction we should ignore the operations
// with invalid DIM values. They may appear in dead code
// after constant propagation.
- auto constDim = fir::getIntIfConstant(dim);
+ auto constDim = fir::getIntIfConstant(impl().getDim());
if (!constDim)
- return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
+ return rewriter.notifyMatchFailure(op, "Nonconstant DIM");
dimVal = *constDim;
- if ((dimVal <= 0 || dimVal > array.getRank()))
- return rewriter.notifyMatchFailure(
- sum, "Invalid DIM for partial SUM reduction");
+ if ((dimVal <= 0 || dimVal > getSourceRank()))
+ return rewriter.notifyMatchFailure(op,
+ "Invalid DIM for partial reduction");
}
+ return dimVal;
+ }
- mlir::Location loc = sum.getLoc();
- fir::FirOpBuilder builder{rewriter, sum.getOperation()};
- mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
- mlir::Value mask = sum.getMask();
+ /// Return hlfir::Entity of the result.
+ hlfir::Entity getResultEntity() const {
+ return hlfir::Entity{op->getResult(0)};
+ }
- mlir::Value resultShape, dimExtent;
- llvm::SmallVector<mlir::Value> arrayExtents;
- if (isTotalReduction)
- arrayExtents = hlfir::genExtentsVector(loc, builder, array);
- else
- std::tie(resultShape, dimExtent) =
- genResultShapeForPartialReduction(loc, builder, array, dimVal);
-
- // If the mask is present and is a scalar, then we'd better load its value
- // outside of the reduction loop making the loop unswitching easier.
- mlir::Value isPresentPred, maskValue;
- if (mask) {
- if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
- // MASK represented by a box might be dynamically optional,
- // so we have to check for its presence before accessing it.
- isPresentPred =
- builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
- }
+ /// Return type of the result (e.g. !hlfir.expr<?xi32>).
+ mlir::Type getResultType() const { return getResultEntity().getType(); }
- if (hlfir::Entity{mask}.isScalar())
- maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
- }
+ /// Return the element type of the result (e.g. i32).
+ mlir::Type getResultElementType() const {
+ return hlfir::getFortranElementType(getResultType());
+ }
- auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::ValueRange inputIndices) -> hlfir::Entity {
- // Loop over all indices in the DIM dimension, and reduce all values.
- // If DIM is not present, do total reduction.
-
- // Initial value for the reduction.
- mlir::Value reductionInitValue =
- fir::factory::createZeroValue(builder, loc, elementType);
-
- // The reduction loop may be unordered if FastMathFlags::reassoc
- // transformations are allowed. The integer reduction is always
- // unordered.
- bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
- static_cast<bool>(sum.getFastmath() &
- mlir::arith::FastMathFlags::reassoc);
+ /// Return rank of the result.
+ unsigned getResultRank() const { return getResultEntity().getRank(); }
- llvm::SmallVector<mlir::Value> extents;
- if (isTotalReduction)
- extents = arrayExtents;
- else
- extents.push_back(
- builder.createConvert(loc, builder.getIndexType(), dimExtent));
+ /// Return the element type of the result.
----------------
tblah wrote:
```suggestion
/// Return the element type of the source.
```
https://github.com/llvm/llvm-project/pull/136071
More information about the flang-commits
mailing list