[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)

Iman Hosseini llvmlistbot at llvm.org
Fri Jan 10 05:18:07 PST 2025


https://github.com/ImanHosseini created https://github.com/llvm/llvm-project/pull/122450

If both the source and the accumulator are splat constants, constant-fold the multi-reduction.

>From 0c7eebf52179c5fc86b31b93cb06f7070b28c5f4 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 13:15:54 +0000
Subject: [PATCH] [MLIR] [Vector] ConstantFold MultiDReduction if both src and
 acc are splat

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp    | 104 ++++++++++++++++++++
 mlir/test/Dialect/Vector/constant-fold.mlir |  54 ++++++++++
 2 files changed, 158 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..a23d952c5760c5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -21,11 +21,13 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IRMapping.h"
@@ -44,6 +46,7 @@
 #include "llvm/ADT/bit.h"
 
 #include <cassert>
+#include <cmath>
 #include <cstdint>
 #include <numeric>
 
@@ -463,10 +466,111 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   build(builder, result, kind, source, acc, reductionDims);
 }
 
+/// Folds a multi-reduction whose source and accumulator are splats holding
+/// attributes of type `AttrT`; `times` source values feed each result element
+/// and `dstType` is the type of the folded result. The primary template is
+/// only declared: each supported attribute type provides an explicit
+/// specialization, which returns a null OpFoldResult when it cannot fold.
+template <class AttrT>
+OpFoldResult foldSplatReduce(AttrT src, AttrT acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType);
+
+/// Floating-point case: each result element is `acc` combined with `times`
+/// copies of `src` under `kind`. Returns a null OpFoldResult for kinds that
+/// are not floating-point combiners.
+template <>
+OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  const APFloat srcVal = src.getValue();
+  const APFloat accVal = acc.getValue();
+
+  // Applies `step` to the running value `times` times, starting from `acc`,
+  // so every intermediate result is rounded as in an ordered element-by-
+  // element reduction. A closed form such as `acc + src * times` rounds only
+  // once and can produce a different value. The loop stops as soon as a step
+  // leaves the value unchanged: `step` depends only on the running value, so
+  // every later step would reproduce it.
+  auto combineRepeatedly = [&](auto step) {
+    APFloat running = accVal;
+    for (int64_t i = 0; i < times; ++i) {
+      APFloat next = step(running);
+      if (next.bitwiseIsEqual(running))
+        break;
+      running = std::move(next);
+    }
+    return running;
+  };
+
+  APFloat result = accVal;
+  switch (kind) {
+  case CombiningKind::ADD:
+    result = combineRepeatedly([&](const APFloat &r) { return r + srcVal; });
+    break;
+  case CombiningKind::MUL:
+    result = combineRepeatedly([&](const APFloat &r) { return r * srcVal; });
+    break;
+  // min/max are idempotent: combining `src` again after the first time
+  // cannot change the result, so the count does not matter.
+  case CombiningKind::MINIMUMF:
+    result = llvm::minimum(accVal, srcVal);
+    break;
+  case CombiningKind::MAXIMUMF:
+    result = llvm::maximum(accVal, srcVal);
+    break;
+  case CombiningKind::MINNUMF:
+    result = llvm::minnum(accVal, srcVal);
+    break;
+  case CombiningKind::MAXNUMF:
+    result = llvm::maxnum(accVal, srcVal);
+    break;
+  default:
+    return {};
+  }
+  return DenseElementsAttr::get(dstType, {result});
+}
+
+/// Integer case: each result element is `acc` combined with `times` copies of
+/// `src` under `kind`, using the wrap-around semantics of APInt arithmetic.
+/// Returns a null OpFoldResult for kinds that are not integer combiners.
+template <>
+OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  const APInt srcVal = src.getValue();
+  const APInt accVal = acc.getValue();
+
+  APInt result = accVal;
+  switch (kind) {
+  case CombiningKind::ADD:
+    // APInt arithmetic wraps, so this equals adding `src` `times` times.
+    result += srcVal * times;
+    break;
+  case CombiningKind::MUL: {
+    // acc * src^times by square-and-multiply. Modular multiplication is
+    // associative, so this matches the element-by-element product while
+    // taking O(log times) steps instead of O(times).
+    APInt base = srcVal;
+    for (uint64_t e = static_cast<uint64_t>(times); e != 0; e >>= 1) {
+      if (e & 1)
+        result *= base;
+      base *= base;
+    }
+    break;
+  }
+  // min/max, and and or are idempotent: combining `src` more than once does
+  // not change the result.
+  case CombiningKind::MINSI:
+    result = llvm::APIntOps::smin(accVal, srcVal);
+    break;
+  case CombiningKind::MAXSI:
+    // Signed comparison; an unsigned one would treat negative values as huge.
+    result = llvm::APIntOps::smax(accVal, srcVal);
+    break;
+  case CombiningKind::MINUI:
+    result = llvm::APIntOps::umin(accVal, srcVal);
+    break;
+  case CombiningKind::MAXUI:
+    result = llvm::APIntOps::umax(accVal, srcVal);
+    break;
+  case CombiningKind::AND:
+    result &= srcVal;
+    break;
+  case CombiningKind::OR:
+    result |= srcVal;
+    break;
+  case CombiningKind::XOR:
+    // Pairs of `src` cancel out; only an odd count leaves one applied.
+    if (times & 0x1)
+      result ^= srcVal;
+    break;
+  default:
+    return {};
+  }
+  return DenseElementsAttr::get(dstType, {result});
+}
+
 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
   // Single parallel dim, this is a noop.
   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
     return getSource();
+
+  // Constant-fold when both the source and the accumulator are splat
+  // constants: each result element then combines the same accumulator value
+  // with the same number of identical source values.
+
+  // Under vector.mask only unmasked lanes are combined, so the element count
+  // computed below would not apply; leave masked reductions alone.
+  if (isMasked())
+    return {};
+  auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
+  auto accAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getAcc());
+  if (!srcAttr || !accAttr || !srcAttr.isSplat() || !accAttr.isSplat())
+    return {};
+  // The folded value is materialized as a splat of the vector result type.
+  auto dstType = dyn_cast<VectorType>(getDestType());
+  if (!dstType)
+    return {};
+
+  // Count the source elements combined into each result element. A scalable
+  // reduced dim has a runtime-dependent length, so there is no fixed count
+  // and count-sensitive kinds (add, mul, xor) would fold to the wrong value.
+  VectorType srcType = getSourceVectorType();
+  ArrayRef<bool> scalableDims = srcType.getScalableDims();
+  int64_t times = 1;
+  for (int64_t dim : getReductionDims()) {
+    if (scalableDims[dim])
+      return {};
+    times *= srcType.getDimSize(dim);
+  }
+
+  CombiningKind kind = getKind();
+  Type eltType = dstType.getElementType();
+  if (isa<FloatType>(eltType))
+    return foldSplatReduce<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
+                                      accAttr.getSplatValue<FloatAttr>(), times,
+                                      kind, dstType);
+  if (isa<IntegerType>(eltType))
+    return foldSplatReduce<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
+                                        accAttr.getSplatValue<IntegerAttr>(),
+                                        times, kind, dstType);
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 66c91d6b2041bf..43c52b4b36ca53 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -11,3 +11,57 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
   %2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
   return %2 : vector<4x4xf16>
 }
+
+// Adding 128 * 128 ones onto a zero accumulator folds to 16384.0.
+// CHECK-LABEL: fold_multid_reduction_f32_add
+func.func @fold_multid_reduction_f32_add() -> vector<1xf32> {
+  %acc = arith.constant dense<0.000000e+00> : vector<1xf32>
+  %src = arith.constant dense<1.000000e+00> : vector<1x128x128xf32>
+  // CHECK: %{{.*}} = arith.constant dense<1.638400e+04> : vector<1xf32>
+  %sum = vector.multi_reduction <add>, %src, %acc [1, 2] : vector<1x128x128xf32> to vector<1xf32>
+  return %sum : vector<1xf32>
+}
+
+// 1.0 multiplied by 2.0 four times (2 * 2 reduced elements) folds to 16.0.
+// CHECK-LABEL: fold_multid_reduction_f32_mul
+func.func @fold_multid_reduction_f32_mul() -> vector<1xf32> {
+  %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
+  %src = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
+  // CHECK: %{{.*}} = arith.constant dense<1.600000e+01> : vector<1xf32>
+  %prod = vector.multi_reduction <mul>, %src, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
+  return %prod : vector<1xf32>
+}
+
+// 16384 ones added onto an accumulator of 1 fold to 16385.
+// CHECK-LABEL: fold_multid_reduction_i32_add
+func.func @fold_multid_reduction_i32_add() -> vector<1xi32> {
+  %acc = arith.constant dense<1> : vector<1xi32>
+  %src = arith.constant dense<1> : vector<1x128x128xi32>
+  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi32>
+  %sum = vector.multi_reduction <add>, %src, %acc [1, 2] : vector<1x128x128xi32> to vector<1xi32>
+  return %sum : vector<1xi32>
+}
+
+// XOR-ing 0xA0A an odd number of times (3) onto 0xFF leaves 0xFF ^ 0xA0A,
+// i.e. 0xAF5 = 2805.
+// CHECK-LABEL: fold_multid_reduction_i32_xor_odd
+func.func @fold_multid_reduction_i32_xor_odd() -> vector<1xi32> {
+  %acc = arith.constant dense<0xFF> : vector<1xi32>
+  %src = arith.constant dense<0xA0A> : vector<1x3xi32>
+  // CHECK: %{{.*}} = arith.constant dense<2805> : vector<1xi32>
+  %res = vector.multi_reduction <xor>, %src, %acc [1] : vector<1x3xi32> to vector<1xi32>
+  return %res : vector<1xi32>
+}
+
+// XOR-ing 0xA0A an even number of times (4) cancels out, leaving 0xFF = 255.
+// CHECK-LABEL: fold_multid_reduction_i32_xor_even
+func.func @fold_multid_reduction_i32_xor_even() -> vector<1xi32> {
+  %acc = arith.constant dense<0xFF> : vector<1xi32>
+  %src = arith.constant dense<0xA0A> : vector<1x4xi32>
+  // CHECK: %{{.*}} = arith.constant dense<255> : vector<1xi32>
+  %res = vector.multi_reduction <xor>, %src, %acc [1] : vector<1x4xi32> to vector<1xi32>
+  return %res : vector<1xi32>
+}
+
+// Same as the i32 case with a 64-bit element type: 1 + 16384 = 16385.
+// CHECK-LABEL: fold_multid_reduction_i64_add
+func.func @fold_multid_reduction_i64_add() -> vector<1xi64> {
+  %acc = arith.constant dense<1> : vector<1xi64>
+  %src = arith.constant dense<1> : vector<1x128x128xi64>
+  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi64>
+  %sum = vector.multi_reduction <add>, %src, %acc [1, 2] : vector<1x128x128xi64> to vector<1xi64>
+  return %sum : vector<1xi64>
+}
\ No newline at end of file



More information about the Mlir-commits mailing list