[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)

Iman Hosseini llvmlistbot at llvm.org
Sun Jan 12 08:57:20 PST 2025


================
@@ -463,10 +465,114 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   build(builder, result, kind, source, acc, reductionDims);
 }
 
+/// Helper to fold a multi-dim reduction whose source and accumulator are both
+/// splats.
+///
+/// Combines the splat element `src` with the accumulator element `acc` as if
+/// `src` had been reduced into `acc` `times` times with the reduction `kind`
+/// (add/max/etc.). The specializations that follow build the result as a
+/// DenseElementsAttr of `dstType` from the single folded value and return a
+/// null OpFoldResult for combining kinds they do not handle.
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+                             ShapedType dstType);
+
+/// Floating-point specialization of foldSplatReduce.
+///
+/// ADD is folded in closed form as `acc + src * times`, MUL multiplies `acc`
+/// by `src` `times` times, and the min/max kinds apply the operation to `acc`
+/// and `src` once (repeating it would not change the result; this assumes
+/// `times` >= 1). Any other combining kind yields a null OpFoldResult.
+template <>
+OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APFloat srcVal = src.getValue();
+  APFloat accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD: {
+    // Express the repetition count in the element's float semantics so the
+    // sum can be formed with one multiplication and one addition.
+    APFloat n = APFloat(srcVal.getSemantics());
+    n.convertFromAPInt(APInt(64, times, true), true,
+                       APFloat::rmNearestTiesToEven);
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+  }
+  case CombiningKind::MUL: {
+    APFloat result = accVal;
+    // The counter matches the 64-bit width of `times`; a plain `int` would
+    // overflow for counts above INT_MAX.
+    for (int64_t i = 0; i < times; ++i)
+      result = result * srcVal;
+    return DenseElementsAttr::get(dstType, {result});
+  }
+  case CombiningKind::MINIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
+  case CombiningKind::MAXIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
+  case CombiningKind::MINNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
+  case CombiningKind::MAXNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
+  default:
+    return {};
+  }
+}
+
+/// Integer specialization of foldSplatReduce.
+///
+/// ADD is folded in closed form as `acc + src * times`, MUL multiplies `acc`
+/// by `src` `times` times, and the min/max/and/or kinds apply the operation to
+/// `acc` and `src` once (repeating it would not change the result; this
+/// assumes `times` >= 1). XOR only depends on the parity of `times`, since
+/// xor-ing the same value twice cancels out. Any other combining kind yields a
+/// null OpFoldResult.
+template <>
+OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APInt srcVal = src.getValue();
+  APInt accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD:
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
+  case CombiningKind::MUL: {
+    APInt result = accVal;
+    // The counter matches the 64-bit width of `times`; a plain `int` would
+    // overflow for counts above INT_MAX.
+    for (int64_t i = 0; i < times; ++i)
+      result *= srcVal;
+    return DenseElementsAttr::get(dstType, {result});
+  }
+  case CombiningKind::MINSI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.slt(srcVal) ? accVal : srcVal});
+  case CombiningKind::MAXSI:
+    // Signed maximum must use the signed comparison; an unsigned compare
+    // would rank negative values above positive ones.
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.sgt(srcVal) ? accVal : srcVal});
+  case CombiningKind::MINUI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.ult(srcVal) ? accVal : srcVal});
+  case CombiningKind::MAXUI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.ugt(srcVal) ? accVal : srcVal});
+  case CombiningKind::AND:
+    return DenseElementsAttr::get(dstType, {accVal & srcVal});
+  case CombiningKind::OR:
+    return DenseElementsAttr::get(dstType, {accVal | srcVal});
+  case CombiningKind::XOR:
+    // An odd number of applications leaves one xor; an even number cancels.
+    return DenseElementsAttr::get(dstType,
+                                  {(times & 0x1) ? accVal ^ srcVal : accVal});
+  default:
+    return {};
+  }
+}
+
 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
   // Single parallel dim, this is a noop.
   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
     return getSource();
+  auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
----------------
ImanHosseini wrote:

Lagrange's `Méchanique Analytique` starts with "There are no figures at all in this work." And they weren't needed (which was the point of the Lagrangian formulation, btw). That's how I feel about empty lines :)

https://github.com/llvm/llvm-project/pull/122450


More information about the Mlir-commits mailing list