[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)
Diego Caballero
llvmlistbot at llvm.org
Mon Jan 13 11:34:39 PST 2025
================
@@ -463,10 +462,157 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
+/// Computes the result of reducing a constant vector `src` with combining kind
+/// `kind` where the accumulator value, `acc`, is also constant. NOTE(review): `times` is presumably the number of reduced elements - confirm.
+template <typename T>
+static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
+ CombiningKind kind,
+ ShapedType dstType);
+// Returns `a` raised to `exponent` (asserted >= 0). TODO: move to APFloat, APInt headers.
+template <typename T>
+static T computePowerOf(const T &a, int64_t exponent);
+
+/// Returns `a` raised to the non-negative integer power `exponent`, computed
+/// by binary exponentiation (O(log exponent) multiplications).
+///
+/// The result has the same float semantics as `a`; `exponent == 0` yields one.
+template <>
+APFloat computePowerOf(const APFloat &a, int64_t exponent) {
+  assert(exponent >= 0 && "negative exponents not supported.");
+  // Invariant (up to rounding): result * base^remainingExponent == a^exponent.
+  APFloat result = APFloat::getOne(a.getSemantics());
+  APFloat base = a;
+  for (int64_t remainingExponent = exponent; remainingExponent > 0;
+       remainingExponent /= 2) {
+    // Fold in the current power of `a` when the low exponent bit is set.
+    if (remainingExponent % 2 != 0)
+      result = result * base;
+    // Skip the last squaring: its value is never used and it could overflow
+    // to infinity for no reason.
+    if (remainingExponent > 1)
+      base = base * base;
+  }
+  return result;
+}
+
+/// Returns `a` raised to the non-negative integer power `exponent`, computed
+/// by binary exponentiation (O(log exponent) multiplications).
+///
+/// The result has the same bit width as `a`; multiplication wraps on overflow
+/// as usual for APInt. `exponent == 0` yields one.
+template <>
+APInt computePowerOf(const APInt &a, int64_t exponent) {
+  assert(exponent >= 0 && "negative exponents not supported.");
+  // Invariant (modulo 2^bitwidth): result * base^remainingExponent ==
+  // a^exponent.
+  APInt result(a.getBitWidth(), 1);
+  APInt base = a;
+  for (int64_t remainingExponent = exponent; remainingExponent > 0;
+       remainingExponent /= 2) {
+    // Fold in the current power of `a` when the low exponent bit is set.
+    if (remainingExponent % 2 != 0)
+      result = result * base;
+    // Skip the last squaring: its value is never used.
+    if (remainingExponent > 1)
+      base = base * base;
+  }
+  return result;
+}
+
+template <>
+OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
+ int64_t times, CombiningKind kind,
+ ShapedType dstType) {
+ APFloat srcVal = src.getValue();
+ APFloat accVal = acc.getValue();
+ switch (kind) {
+ case CombiningKind::ADD: {
+ APFloat n = APFloat(srcVal.getSemantics());
+ n.convertFromAPInt(APInt(64, times, true), true,
+ APFloat::rmNearestTiesToEven);
----------------
dcaballe wrote:
Would this be a problem if the user expects a non-default rounding mode? We have been adding FMF and RM bottom-up in the IR but it's lacking at vector level so I'm wondering if this would lead to an unexpected outcome. Perhaps @chelini, @kuhar could provide some feedback?
Worst case, I guess we could enable this folder under a flag...
https://github.com/llvm/llvm-project/pull/122450
More information about the Mlir-commits
mailing list