[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Jan 12 08:25:17 PST 2025


================
@@ -463,10 +465,114 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   build(builder, result, kind, source, acc, reductionDims);
 }
 
+/// Helper function to reduce a multi reduction where src and acc are splat
+/// Folds src @^times acc into OpFoldResult where @ is the reduction operation
+/// (add/max/etc.)
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+                             ShapedType dstType);
----------------
banach-space wrote:

A couple of minor comments/questions.

1. Is `src add^times acc` fairly standard syntax? I would just write `reduce(src, acc, numElems)`. Also, why not `src kind^times acc` (i.e., reuse `kind`)?

2. Missing `static`?

3. Note that this function doesn't know anything about "multi reduction". I would remove references to that and write this instead (feel free to reuse):

```cpp
/// Computes the result of reducing a constant vector where the accumulator
/// value, `acc`, is also constant.
template <typename T>
static OpFoldResult computeConstantReduction(T srcVal, T acc, int64_t numElems,
                                             CombiningKind kind,
                                             ShapedType dstType);
```
 
I am also suggesting some renaming.
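To make point 1 concrete: `reduce(src, acc, numElems)` over a splat never needs to materialize the vector, because each combining kind collapses to a closed form. Below is a minimal standalone sketch on plain `uint64_t` values. `Kind` and this `computeConstantReduction` overload are hypothetical stand-ins, not the MLIR API; the real helper would presumably work on attribute values and return an `OpFoldResult`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a subset of mlir::vector::CombiningKind
// (integer kinds only; MinU/MaxU compare as unsigned).
enum class Kind { Add, Mul, MinU, MaxU, And, Or, Xor };

// Hypothetical helper: returns the result of combining `acc` with `numElems`
// copies of `srcVal`, i.e. reduce(splat(srcVal), acc), without materializing
// the vector. Arithmetic wraps modulo 2^64 (unsigned overflow is well defined).
static uint64_t computeConstantReduction(uint64_t srcVal, uint64_t acc,
                                         uint64_t numElems, Kind kind) {
  if (numElems == 0)
    return acc; // Nothing to combine: the accumulator is the result.
  switch (kind) {
  case Kind::Add:
    return acc + srcVal * numElems; // Repeated addition == multiplication.
  case Kind::Mul: {
    // Repeated multiplication == exponentiation; square-and-multiply keeps
    // this O(log numElems).
    uint64_t result = acc;
    uint64_t base = srcVal;
    for (uint64_t e = numElems; e != 0; e >>= 1) {
      if (e & 1)
        result *= base;
      base *= base;
    }
    return result;
  }
  case Kind::MinU:
    return std::min(acc, srcVal); // Idempotent: one copy is enough.
  case Kind::MaxU:
    return std::max(acc, srcVal);
  case Kind::And:
    return acc & srcVal;
  case Kind::Or:
    return acc | srcVal;
  case Kind::Xor:
    // x ^ x == 0, so only the parity of numElems matters.
    return (numElems & 1) ? (acc ^ srcVal) : acc;
  }
  return acc; // Not reached for a valid `kind`.
}
```

For example, reducing a splat of four 3s into an accumulator of 10 with `Add` gives 22, and reducing a splat of ten 2s into 5 with `Mul` gives 5 * 2^10 = 5120. Only `Add` and `Mul` actually depend on the element count; the idempotent kinds care only whether it is zero, and `Xor` only about its parity. Floating-point kinds are deliberately left out: repeated floating-point addition is not, in general, bit-identical to a single multiplication, which a real folder may need to account for.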

https://github.com/llvm/llvm-project/pull/122450

