[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Jan 12 08:25:18 PST 2025


================
@@ -463,10 +465,114 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   build(builder, result, kind, source, acc, reductionDims);
 }
 
+/// Helper for folding a multi-dim reduction whose `src` and `acc` operands
+/// are both constant splats. Conceptually it evaluates
+///   acc @ src @ src @ ... @ src      (`src` repeated `times` times)
+/// where `@` is the binary operation selected by `kind` (add, mul, min/max,
+/// and/or/xor, ...), and returns the result as a DenseElementsAttr of type
+/// `dstType`. Returns a null OpFoldResult for kinds that are not handled.
+/// Only the FloatAttr and IntegerAttr specializations below are defined.
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+                             ShapedType dstType);
+
+/// Floating-point specialization of foldSplatReduce. All arithmetic is done
+/// in the element type's own semantics via APFloat.
+template <>
+OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APFloat srcVal = src.getValue();
+  APFloat accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD: {
+    // Fold as acc + src * times. Bail out unless `times` converts exactly
+    // into the element type: e.g. in f16 a count of 65536 rounds to +inf,
+    // and a splat of 0.0 would then fold to 0.0 * inf = NaN instead of acc.
+    // Note: a single multiply-add may round differently from summing the
+    // elements one by one.
+    APFloat n = APFloat(srcVal.getSemantics());
+    if (n.convertFromAPInt(APInt(64, times, /*isSigned=*/true),
+                           /*IsSigned=*/true,
+                           APFloat::rmNearestTiesToEven) != APFloat::opOK)
+      return {};
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+  }
+  case CombiningKind::MUL: {
+    // Multiply one step at a time so rounding matches evaluating
+    // acc * src * ... * src left to right. The counter is int64_t so it
+    // cannot overflow before reaching `times`.
+    APFloat result = accVal;
+    for (int64_t i = 0; i < times; ++i)
+      result = result * srcVal;
+    return DenseElementsAttr::get(dstType, {result});
+  }
+  // The min/max kinds are idempotent: combining `src` with itself yields
+  // `src`, so one application against `acc` covers every repetition
+  // (assumes times >= 1 — NOTE(review): confirm times == 0 cannot occur).
+  case CombiningKind::MINIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
+  case CombiningKind::MAXIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
+  case CombiningKind::MINNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
+  case CombiningKind::MAXNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
+  default:
+    // Integer-only kinds (and/or/xor, signed/unsigned min/max) don't apply.
+    return {};
+  }
+}
+
+/// Integer specialization of foldSplatReduce. APInt add/mul wrap modulo
+/// 2^bitwidth, so the closed forms below produce the same bits as combining
+/// `src` into `acc` one element at a time.
+template <>
+OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APInt srcVal = src.getValue();
+  APInt accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD:
+    // acc + src + ... + src == acc + src * times (mod 2^bitwidth).
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
+  case CombiningKind::MUL: {
+    // acc * src^times via exponentiation by squaring: O(log(times))
+    // multiplies instead of O(times), with identical results because
+    // wrapping multiplication is associative. Non-positive `times` -> acc.
+    APInt power(srcVal.getBitWidth(), 1);
+    APInt base = srcVal;
+    for (uint64_t e = times > 0 ? static_cast<uint64_t>(times) : 0; e != 0;
+         e >>= 1) {
+      if (e & 1)
+        power *= base;
+      base *= base;
+    }
+    return DenseElementsAttr::get(dstType, {accVal * power});
+  }
+  // min/max/and/or are idempotent: folding `src` in once is the same as
+  // folding it in `times` times. The APIntOps helpers make the signedness
+  // of each min/max comparison explicit.
+  case CombiningKind::MINSI:
+    return DenseElementsAttr::get(dstType,
+                                  {llvm::APIntOps::smin(accVal, srcVal)});
+  case CombiningKind::MAXSI:
+    return DenseElementsAttr::get(dstType,
+                                  {llvm::APIntOps::smax(accVal, srcVal)});
+  case CombiningKind::MINUI:
+    return DenseElementsAttr::get(dstType,
+                                  {llvm::APIntOps::umin(accVal, srcVal)});
+  case CombiningKind::MAXUI:
+    return DenseElementsAttr::get(dstType,
+                                  {llvm::APIntOps::umax(accVal, srcVal)});
+  case CombiningKind::AND:
+    return DenseElementsAttr::get(dstType, {accVal & srcVal});
+  case CombiningKind::OR:
+    return DenseElementsAttr::get(dstType, {accVal | srcVal});
+  case CombiningKind::XOR:
+    // x ^ x == 0, so only the parity of `times` matters.
+    return DenseElementsAttr::get(dstType,
+                                  {times & 0x1 ? accVal ^ srcVal : accVal});
+  default:
+    // Float-only kinds (minimumf/maximumf/minnumf/maxnumf) don't apply.
+    return {};
+  }
+}
+
 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
   // Single parallel dim, this is a noop.
   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
     return getSource();
+  auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
+  auto accAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getAcc());
+  if (!srcAttr || !accAttr)
+    return {};
+  if (!srcAttr.isSplat() || !accAttr.isSplat())
+    return {};
----------------
banach-space wrote:

[nit] I would probably combine these two checks — they are basically one condition ("are the inputs constant splats?").

https://github.com/llvm/llvm-project/pull/122450


More information about the Mlir-commits mailing list