[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)

Diego Caballero llvmlistbot at llvm.org
Mon Jan 13 11:34:39 PST 2025


================
@@ -463,10 +462,157 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   build(builder, result, kind, source, acc, reductionDims);
 }
 
+/// Computes the result of reducing a constant vector where the accumulator
+/// value, `acc`, is also constant.
+template <typename T>
+static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
+                                             CombiningKind kind,
+                                             ShapedType dstType);
+// TODO: move to APFloat, APInt headers.
+template <typename T>
+static T computePowerOf(const T &a, int64_t exponent);
+
+template <>
+APFloat computePowerOf(const APFloat &a, int64_t exponent) {
+  assert(exponent >= 0 && "negative exponents not supported.");
+  // Square-and-multiply; invariant: result * base^remainingExponent == a^exponent.
+  APFloat result = APFloat::getOne(a.getSemantics());
+  APFloat base = a;
+  int64_t remainingExponent = exponent;
+  while (remainingExponent > 0) {
+    if (remainingExponent % 2 == 1)
+      result = result * base;
+    remainingExponent /= 2;
+    if (remainingExponent > 0)
+      base = base * base;
+  }
+  return result;
+}
+
+template <>
+APInt computePowerOf(const APInt &a, int64_t exponent) {
+  assert(exponent >= 0 && "negative exponents not supported.");
+  // Same square-and-multiply loop; APInt multiplication wraps modulo 2^bitwidth.
+  APInt result(a.getBitWidth(), 1);
+  APInt base = a;
+  int64_t remainingExponent = exponent;
+  while (remainingExponent > 0) {
+    if (remainingExponent % 2 == 1)
+      result = result * base;
+    remainingExponent /= 2;
+    if (remainingExponent > 0)
+      base = base * base;
+  }
+  return result;
+}
+
+template <>
+OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
+                                      int64_t times, CombiningKind kind,
+                                      ShapedType dstType) {
+  APFloat srcVal = src.getValue();
+  APFloat accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD: {
+    APFloat n = APFloat(srcVal.getSemantics());
+    n.convertFromAPInt(APInt(64, times, true), true,
+                       APFloat::rmNearestTiesToEven);
----------------
dcaballe wrote:

Would this be a problem if the user expects a non-default rounding mode? We have been adding fast-math flags (FMF) and rounding modes (RM) bottom-up in the IR, but they are still missing at the vector level, so I'm wondering whether this could lead to an unexpected outcome. Perhaps @chelini or @kuhar could provide some feedback?
Worst case, I guess we could enable this folder under a flag...
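To illustrate where rounding enters, here is a minimal LLVM-free sketch that uses plain `float` as a stand-in for `APFloat` with IEEE single semantics. It assumes the default round-to-nearest-even mode and `FLT_EVAL_METHOD == 0` (each operation rounds to single precision, as on x86-64 and AArch64). The helpers `sequentialAdd` and `closedFormAdd` are hypothetical and not part of the PR. The first adds `src` into `acc` one element at a time, which is only one possible evaluation order for the unfolded reduction. The second mirrors the shape of the ADD case in the hunk by converting `times` to the element type before multiplying.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: add `times` copies of `src` into `acc` one at a time.
// Every addition rounds to float, so the sum stalls once `src` drops below
// half an ulp of `acc`.
float sequentialAdd(float src, float acc, int64_t times) {
  assert(times >= 0 && "negative counts not supported");
  for (int64_t i = 0; i < times; ++i)
    acc += src;
  return acc;
}

// Hypothetical helper: closed-form fold `acc + src * times`. Converting the
// integer `times` to float may itself round (to nearest even by default) once
// `times` exceeds 2^24, before the multiply and the add round again.
float closedFormAdd(float src, float acc, int64_t times) {
  assert(times >= 0 && "negative counts not supported");
  float n = static_cast<float>(times);
  return acc + src * n;
}
```

For example, `sequentialAdd(0.1f, 0.0f, 30000000)` stalls near 2^21, while `closedFormAdd(0.1f, 0.0f, 30000000)` yields roughly 3e6. Likewise, `static_cast<float>((int64_t{1} << 24) + 1)` rounds to 2^24. The conversion of `times` and the final multiply-add both round, so under a non-default rounding mode the result could differ from what a folder hard-coded to `rmNearestTiesToEven` produces.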

https://github.com/llvm/llvm-project/pull/122450

