[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)
Iman Hosseini
llvmlistbot at llvm.org
Sun Jan 12 08:57:20 PST 2025
================
@@ -463,10 +465,114 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
+/// Helper for folding a multi-dimensional reduction whose source and
+/// accumulator are both splat constants.
+/// Conceptually computes `acc @ src @ ... @ src`, with `src` combined `times`
+/// times, where `@` is the combining operation selected by `kind`
+/// (add/max/etc.). The FloatAttr and IntegerAttr specializations below build
+/// the result as a DenseElementsAttr of type `dstType` and return a null
+/// OpFoldResult for combining kinds they do not handle.
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+ ShapedType dstType);
+
+/// Floating-point specialization: folds `times` applications of the splat
+/// value `src` into the splat accumulator `acc` using `kind`, and returns the
+/// result as a DenseElementsAttr of type `dstType`. Returns a null
+/// OpFoldResult for combining kinds that are not handled here.
+template <>
+OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APFloat srcVal = src.getValue();
+  APFloat accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD: {
+    // Convert the repetition count to the source's float semantics so the sum
+    // can be formed as acc + src * times.
+    // NOTE(review): a single multiply may round differently from `times`
+    // sequential additions -- confirm this is acceptable for the fold.
+    APFloat n = APFloat(srcVal.getSemantics());
+    n.convertFromAPInt(APInt(64, times, true), true,
+                       APFloat::rmNearestTiesToEven);
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+  }
+  case CombiningKind::MUL: {
+    APFloat result = accVal;
+    // The counter matches the type of `times` so it cannot overflow before
+    // reaching the bound.
+    for (int64_t i = 0; i < times; ++i) {
+      result = result * srcVal;
+    }
+    return DenseElementsAttr::get(dstType, {result});
+  }
+  // For min/max, combining with the same value repeatedly equals combining
+  // with it once.
+  // NOTE(review): this presumes times >= 1 -- confirm against the caller.
+  case CombiningKind::MINIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
+  case CombiningKind::MAXIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
+  case CombiningKind::MINNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
+  case CombiningKind::MAXNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
+  default:
+    return {};
+  }
+}
+
+/// Integer specialization: folds `times` applications of the splat value
+/// `src` into the splat accumulator `acc` using `kind`, and returns the result
+/// as a DenseElementsAttr of type `dstType`. Returns a null OpFoldResult for
+/// combining kinds that are not handled here.
+template <>
+OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APInt srcVal = src.getValue();
+  APInt accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD:
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
+  case CombiningKind::MUL: {
+    APInt result = accVal;
+    // The counter matches the type of `times` so it cannot overflow before
+    // reaching the bound.
+    for (int64_t i = 0; i < times; ++i) {
+      result *= srcVal;
+    }
+    return DenseElementsAttr::get(dstType, {result});
+  }
+  // min/max/and/or: combining with the same value repeatedly equals combining
+  // with it once.
+  // NOTE(review): this presumes times >= 1 -- confirm against the caller.
+  case CombiningKind::MINSI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.slt(srcVal) ? accVal : srcVal});
+  case CombiningKind::MAXSI:
+    // Signed maximum must use the signed comparison (sgt), mirroring MINSI's
+    // slt; an unsigned compare would rank negative values above positives.
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.sgt(srcVal) ? accVal : srcVal});
+  case CombiningKind::MINUI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.ult(srcVal) ? accVal : srcVal});
+  case CombiningKind::MAXUI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.ugt(srcVal) ? accVal : srcVal});
+  case CombiningKind::AND:
+    return DenseElementsAttr::get(dstType, {accVal & srcVal});
+  case CombiningKind::OR:
+    return DenseElementsAttr::get(dstType, {accVal | srcVal});
+  case CombiningKind::XOR:
+    // Pairs of identical XOR operands cancel, so only the parity of `times`
+    // matters.
+    return DenseElementsAttr::get(dstType,
+                                  {times & 0x1 ? accVal ^ srcVal : accVal});
+  default:
+    return {};
+  }
+}
+
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
// Single parallel dim, this is a noop.
if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
return getSource();
+ auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
----------------
ImanHosseini wrote:
Lagrange's `Méchanique Analytique` starts with "There are no figures at all in this work." And they weren't needed (which was the point of Lagrangian formulation btw). That's how I feel about empty lines :)
https://github.com/llvm/llvm-project/pull/122450
More information about the Mlir-commits
mailing list