[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)

Iman Hosseini llvmlistbot at llvm.org
Tue Jan 14 05:53:55 PST 2025


================
@@ -463,10 +462,157 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   build(builder, result, kind, source, acc, reductionDims);
 }
 
+/// Computes the result of reducing a constant vector where the accumulator
+/// value, `acc`, is also constant.
+template <typename T>
+static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
+                                             CombiningKind kind,
+                                             ShapedType dstType);
+// TODO: move to APFloat, APInt headers.
+template <typename T>
+static T computePowerOf(const T &a, int64_t exponent);
+
+template <>
+APFloat computePowerOf(const APFloat &a, int64_t exponent) {
+  assert(exponent >= 0 && "negative exponents not supported.");
+  if (exponent == 0) {
+    return APFloat::getOne(a.getSemantics());
+  }
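+  // Right-to-left binary exponentiation: `base` holds a^(2^i) for the current
+  // bit i of the exponent, and `acc` collects the factors for the set bits.
+  APFloat acc = APFloat::getOne(a.getSemantics());
+  APFloat base = a;
+  int64_t remainingExponent = exponent;
+  while (remainingExponent > 0) {
+    if (remainingExponent % 2 == 1)
+      acc = acc * base;
+    base = base * base;
+    remainingExponent /= 2;
+  }
+  return acc;
+}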
+
+template <>
+APInt computePowerOf(const APInt &a, int64_t exponent) {
+  assert(exponent >= 0 && "negative exponents not supported.");
+  if (exponent == 0) {
+    return APInt(a.getBitWidth(), 1);
+  }
+  // Right-to-left binary exponentiation; APInt multiplication wraps modulo
+  // 2^bitwidth, so overflow is well defined.
+  APInt acc(a.getBitWidth(), 1);
+  APInt base = a;
+  int64_t remainingExponent = exponent;
+  while (remainingExponent > 0) {
+    if (remainingExponent % 2 == 1)
+      acc = acc * base;
+    base = base * base;
+    remainingExponent /= 2;
+  }
+  return acc;
+}
+
+template <>
+OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
+                                      int64_t times, CombiningKind kind,
+                                      ShapedType dstType) {
+  APFloat srcVal = src.getValue();
+  APFloat accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD: {
+    APFloat n = APFloat(srcVal.getSemantics());
+    n.convertFromAPInt(APInt(64, times, true), true,
+                       APFloat::rmNearestTiesToEven);
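The hunk above stops partway through the `ADD` case. To give some intuition for why a splat-splat reduction can be folded cheaply, here is a self-contained sketch over `uint64_t`, which wraps modulo 2^64 like a 64-bit `APInt`. It models only a subset of `CombiningKind`; the names `Kind`, `combine`, `loopReduction` and `foldSplatReduction` are hypothetical, and this is not the PR's implementation:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical subset of vector::CombiningKind, unsigned integers only.
enum class Kind { Add, Mul, MinUI, MaxUI, And, Or, Xor };

// Binary exponentiation; uint64_t arithmetic wraps modulo 2^64,
// matching a 64-bit APInt.
uint64_t powWrap(uint64_t base, uint64_t exponent) {
  uint64_t result = 1;
  while (exponent > 0) {
    if (exponent & 1)
      result *= base;
    base *= base;
    exponent >>= 1;
  }
  return result;
}

// Combines one element into the accumulator (reference semantics).
uint64_t combine(Kind kind, uint64_t lhs, uint64_t rhs) {
  switch (kind) {
  case Kind::Add: return lhs + rhs;
  case Kind::Mul: return lhs * rhs;
  case Kind::MinUI: return std::min(lhs, rhs);
  case Kind::MaxUI: return std::max(lhs, rhs);
  case Kind::And: return lhs & rhs;
  case Kind::Or: return lhs | rhs;
  case Kind::Xor: return lhs ^ rhs;
  }
  throw std::invalid_argument("unknown combining kind");
}

// Naive reduction: combine `times` copies of `src` into `acc`, one at a time.
uint64_t loopReduction(Kind kind, uint64_t src, uint64_t acc, uint64_t times) {
  for (uint64_t i = 0; i < times; ++i)
    acc = combine(kind, acc, src);
  return acc;
}

// Closed-form fold of the same reduction, without visiting each element.
uint64_t foldSplatReduction(Kind kind, uint64_t src, uint64_t acc,
                            uint64_t times) {
  if (times == 0)
    return acc; // Nothing to combine.
  switch (kind) {
  case Kind::Add:
    return acc + src * times; // n additions of src == one multiplication
  case Kind::Mul:
    return acc * powWrap(src, times); // n multiplications == src^n
  case Kind::MinUI:
  case Kind::MaxUI:
  case Kind::And:
  case Kind::Or:
    return combine(kind, acc, src); // idempotent: the count is irrelevant
  case Kind::Xor:
    return (times % 2 == 1) ? (acc ^ src) : acc; // pairs of src cancel out
  }
  throw std::invalid_argument("unknown combining kind");
}
```

`ADD` and `MUL` collapse to one multiplication or an O(log n) power, the idempotent kinds ignore the element count entirely, and `XOR` depends only on its parity. For floating-point `ADD`/`MUL` the closed form is not bit-identical to an element-by-element loop, because the rounding steps differ.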
----------------
ImanHosseini wrote:

```
because I could see this being expanded to something like gpu.subgroup_reduce that may end up doing something very much hardware-dependent.
```
The concern, then, is that folding would make something that is runtime/hardware-dependent (without needing to be) *not* hardware-dependent? Something that:
1. Does not need to be runtime/hardware-dependent: it's a constant.
2. Can be folded at no runtime cost: folding the constant would actually be _faster_.
3. Can be folded at no precision cost: the folded result would in fact be _more_ accurate.
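Point 3 can be made concrete with ordinary `float` standing in for `APFloat`. The quoted hunk converts `times` into a float `n` but is cut off before the combine step; this sketch assumes the `ADD` fold then computes `acc + n * src`, which rounds at most twice (converting `n` is exact for counts below 2^24), whereas an element-by-element loop rounds once per element. Both function names are hypothetical:

```cpp
#include <cassert>
#include <cstdint>

// Element-by-element reduction of `times` copies of `src`:
// one rounding per addition.
float sequentialSplatSum(float src, float acc, int64_t times) {
  for (int64_t i = 0; i < times; ++i)
    acc += src;
  return acc;
}

// Assumed closed-form fold of the same reduction: acc + n * src, where n is
// the element count converted to float (exact here because times < 2^24).
float foldedSplatSum(float src, float acc, int64_t times) {
  float n = static_cast<float>(times);
  return acc + n * src;
}
```

With `src = 0.1f`, `acc = 0.0f` and ten million elements, the folded form yields exactly `1000000.0f`, while the sequential loop overshoots by several percent: once the running sum passes 2^19, every addition of `0.1f` rounds to an increment of `0.125`.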
In this case the operands are splat-splat, but in general a partial reduction does not even return a consistent result *on the same hardware*, because the order in which it is applied may change from run to run. How is that desirable? How is that even consistent? If partial ordering should be canonical for reductions, in what order should it be applied?
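The order-dependence argument rests on floating-point addition not being associative. A minimal sketch (hypothetical helper names, plain `float`) that reduces the same elements left-to-right and right-to-left:

```cpp
#include <cassert>
#include <vector>

// Left-to-right linear reduction.
float reduceForward(const std::vector<float> &values) {
  float acc = 0.0f;
  for (float x : values)
    acc += x;
  return acc;
}

// Right-to-left linear reduction over the same elements.
float reduceBackward(const std::vector<float> &values) {
  float acc = 0.0f;
  for (auto it = values.rbegin(); it != values.rend(); ++it)
    acc += *it;
  return acc;
}
```

For `{1e8f, 1.0f, -1e8f, 1.0f}` the forward order returns `1.0f` and the backward order returns `0.0f`: adjacent floats near 1e8 are 8 apart, so adding `1.0f` to `1e8f` or `-1e8f` is rounded away. A reduction whose association order varies between runs can therefore vary in its result as well.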
I've seen discussions of this where the choice was between being fast *or* hardware-independent, or between being fast *or* more accurate. This case involves neither trade-off.
Why would we prefer to be needlessly hardware-dependent, less accurate, and *slower*? It's fine if some user somehow wants that, but why should it be the default?

https://github.com/llvm/llvm-project/pull/122450

