[flang-commits] [flang] [flang] add simplification for ProductOp intrinsic (PR #169575)

via flang-commits flang-commits at lists.llvm.org
Tue Nov 25 14:01:17 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (stomfaig)

<details>
<summary>Changes</summary>

Add a simplification for `ProductOp` by implementing support for it in `ReductionConversion` and adding that conversion to the pattern list of the `SimplifyHLFIRIntrinsics` pass.

Closes: https://github.com/llvm/llvm-project/issues/169433

---
Full diff: https://github.com/llvm/llvm-project/pull/169575.diff


1 Files Affected:

- (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+56) 


``````````diff
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index ce8ebaa803f47..67b43a346747d 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <type_traits>
 
 namespace hlfir {
 #define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -931,6 +932,43 @@ class SumAsElementalConverter
   mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2);
 };
 
+/// Reduction converter for PRODUCT; the running value is seeded with one.
+class ProductAsElementalConverter
+    : public NumericReductionAsElementalConverterBase<hlfir::ProductOp> {
+  using Base = NumericReductionAsElementalConverterBase;
+public:
+  ProductAsElementalConverter(hlfir::ProductOp op,
+                              mlir::PatternRewriter &rewriter)
+      : Base{op, rewriter} {}
+
+private:
+  virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
+      [[maybe_unused]] mlir::ValueRange oneBasedIndices,
+      [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
+      final {
+    mlir::Type ty = getResultElementType();
+    if (mlir::isa<mlir::IntegerType>(ty))
+      return {builder.createIntegerConstant(loc, ty, 1)};
+    if (mlir::isa<mlir::FloatType>(ty))
+      return {builder.createRealConstant(loc, ty, 1u)};
+    fir::factory::Complex cplx{builder, loc};
+    mlir::Type partTy = cplx.getComplexPartType(ty);
+    return {cplx.createComplex(ty, builder.createRealConstant(loc, partTy, 1u),
+                               builder.createRealZeroConstant(loc, partTy))};
+  }
+  virtual llvm::SmallVector<mlir::Value>
+  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
+                   hlfir::Entity array,
+                   mlir::ValueRange oneBasedIndices) final {
+    checkReductions(currentValue);
+    hlfir::Entity elementValue =
+        hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
+    return {genScalarMult(currentValue[0], elementValue)};
+  }
+  // Generate scalar multiplication of two values of the same data type.
+  mlir::Value genScalarMult(mlir::Value value1, mlir::Value value2);
+};
+
 /// Base class for logical reductions like ALL, ANY, COUNT.
 /// They do not have MASK and FastMathFlags.
 template <typename OpT>
@@ -1194,6 +1232,20 @@ mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1,
   llvm_unreachable("unsupported SUM reduction type");
 }
 
+// Emit the multiply op matching the operands' type: real, complex or integer.
+mlir::Value ProductAsElementalConverter::genScalarMult(mlir::Value value1,
+                                                       mlir::Value value2) {
+  mlir::Type ty = value1.getType();
+  assert(ty == value2.getType() && "reduction values' types do not match");
+  if (mlir::isa<mlir::FloatType>(ty))
+    return mlir::arith::MulFOp::create(builder, loc, value1, value2);
+  if (mlir::isa<mlir::ComplexType>(ty))
+    return fir::MulcOp::create(builder, loc, value1, value2);
+  if (mlir::isa<mlir::IntegerType>(ty))
+    return mlir::arith::MulIOp::create(builder, loc, value1, value2);
+  llvm_unreachable("unsupported MUL reduction type");
+}
+
 mlir::Value ReductionAsElementalConverter::genMaskValue(
     mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) {
   mlir::OpBuilder::InsertionGuard guard(builder);
@@ -1265,6 +1317,9 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
     } else if constexpr (std::is_same_v<Op, hlfir::SumOp>) {
       SumAsElementalConverter converter{op, rewriter};
       return converter.convert();
+    } else if constexpr (std::is_same_v<Op, hlfir::ProductOp>) {
+      ProductAsElementalConverter converter{op, rewriter};
+      return converter.convert();
     }
     return rewriter.notifyMatchFailure(op, "unexpected reduction operation");
   }
@@ -3158,6 +3213,7 @@ class SimplifyHLFIRIntrinsics
     mlir::RewritePatternSet patterns(context);
     patterns.insert<TransposeAsElementalConversion>(context);
     patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
+    patterns.insert<ReductionConversion<hlfir::ProductOp>>(context);
     patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
     patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
     patterns.insert<CmpCharOpConversion>(context);

``````````

</details>


https://github.com/llvm/llvm-project/pull/169575


More information about the flang-commits mailing list