[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)
Andrzej Warzyński
llvmlistbot at llvm.org
Sun Jan 12 08:25:17 PST 2025
================
@@ -463,10 +465,114 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
+/// Helper function to reduce a multi reduction where src and acc are splat
+/// Folds src @^times acc into OpFoldResult where @ is the reduction operation
+/// (add/max/etc.)
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+ ShapedType dstType);
----------------
banach-space wrote:
A couple of minor comments/questions.
1. Is `src add^times acc` fairly standard syntax? I would just write: `reduce(src, acc, numElems)`. Also, why not `src kind^times acc`? (i.e. re-use `kind`).
2. Missing `static`?
3. Note that this method doesn't know anything about "multi reduction". I would remove references to that and write this instead (feel free to re-use):
```cpp
// Computes the result of reducing a constant vector where the accumulator
// value, `acc`, is also constant.
template <typename T>
static OpFoldResult computeConstantReduction(T srcVal, T acc, int64_t numElems,
                                             CombiningKind kind,
                                             ShapedType dstType);
```
I am also suggesting some re-naming.
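To make point 1 concrete, here is a minimal sketch of what `reduce(src, acc, numElems)` computes for integer kinds. It assumes plain `int64_t` values instead of MLIR attributes; `Kind`, `combine`, and `reduceByLoop` are hypothetical illustrations, not MLIR APIs, and `computeConstantReduction` only borrows the name suggested above (the PR's actual implementation may differ). The point is that no per-element loop is needed: add scales by `numElems`, mul raises to a power, xor depends only on parity, and min/max/and/or are idempotent.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical, integer-only stand-in for vector::CombiningKind.
enum class Kind { Add, Mul, MinSI, MaxSI, And, Or, Xor };

// One combining step. Add/mul are done in uint64_t so they wrap modulo 2^64
// instead of hitting signed-overflow UB.
static int64_t combine(int64_t a, int64_t b, Kind kind) {
  uint64_t ua = static_cast<uint64_t>(a), ub = static_cast<uint64_t>(b);
  switch (kind) {
  case Kind::Add: return static_cast<int64_t>(ua + ub);
  case Kind::Mul: return static_cast<int64_t>(ua * ub);
  case Kind::MinSI: return std::min(a, b);
  case Kind::MaxSI: return std::max(a, b);
  case Kind::And: return a & b;
  case Kind::Or: return a | b;
  case Kind::Xor: return a ^ b;
  }
  return a;
}

// Reference semantics: fold `numElems` copies of `srcVal` into `acc`,
// one element at a time.
static int64_t reduceByLoop(int64_t srcVal, int64_t acc, int64_t numElems,
                            Kind kind) {
  for (int64_t i = 0; i < numElems; ++i)
    acc = combine(acc, srcVal, kind);
  return acc;
}

// Closed form of the same reduction for a splat source.
static int64_t computeConstantReduction(int64_t srcVal, int64_t acc,
                                        int64_t numElems, Kind kind) {
  assert(numElems >= 0 && "element count must be non-negative");
  if (numElems == 0)
    return acc; // Nothing to reduce.
  uint64_t n = static_cast<uint64_t>(numElems);
  switch (kind) {
  case Kind::Add: // acc + n * src
    return combine(acc, static_cast<int64_t>(static_cast<uint64_t>(srcVal) * n),
                   Kind::Add);
  case Kind::Mul: { // acc * src^n, via square-and-multiply
    uint64_t pow = 1, base = static_cast<uint64_t>(srcVal);
    for (; n > 0; n >>= 1) {
      if (n & 1)
        pow *= base;
      base *= base;
    }
    return combine(acc, static_cast<int64_t>(pow), Kind::Mul);
  }
  case Kind::Xor: // x ^ x == 0, so only the parity of n matters.
    return (n % 2) ? (acc ^ srcVal) : acc;
  default: // min/max/and/or are idempotent: src op src == src.
    return combine(acc, srcVal, kind);
  }
}
```

For example, `computeConstantReduction(3, 10, 4, Kind::Add)` gives `10 + 4 * 3 = 22`, matching `reduceByLoop(3, 10, 4, Kind::Add)`. The helpers are `static` for the reason raised in point 2: they are file-local and need no external linkage.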
https://github.com/llvm/llvm-project/pull/122450