[Mlir-commits] [mlir] [MLIR] [Vector] ConstantFold MultiDReduction (PR #122450)
Iman Hosseini
llvmlistbot at llvm.org
Mon Jan 13 13:56:41 PST 2025
https://github.com/ImanHosseini updated https://github.com/llvm/llvm-project/pull/122450
>From 0c7eebf52179c5fc86b31b93cb06f7070b28c5f4 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 13:15:54 +0000
Subject: [PATCH 01/16] [MLIR] [Vector] ConstantFold MultiDReduction if both
src and acc are splat
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 104 ++++++++++++++++++++
mlir/test/Dialect/Vector/constant-fold.mlir | 54 ++++++++++
2 files changed, 158 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..a23d952c5760c5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -21,11 +21,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
@@ -44,6 +46,7 @@
#include "llvm/ADT/bit.h"
#include <cassert>
+#include <cmath>
#include <cstdint>
#include <numeric>
@@ -463,10 +466,111 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+ ShapedType dstType);
+
+template <>
+OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+ CombiningKind kind, ShapedType dstType) {
+ APFloat srcVal = src.getValue();
+ APFloat accVal = acc.getValue();
+ switch (kind) {
+ case CombiningKind::ADD:{
+ APFloat n = APFloat(srcVal.getSemantics());
+ n.convertFromAPInt(APInt(64, times, true), true,
+ APFloat::rmNearestTiesToEven);
+ return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+ }
+ case CombiningKind::MUL: {
+ APFloat result = accVal;
+ for (int i = 0; i < times; ++i) {
+ result = result * srcVal;
+ }
+ return DenseElementsAttr::get(dstType, {result});
+ }
+ case CombiningKind::MINIMUMF:
+ return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
+ case CombiningKind::MAXIMUMF:
+ return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
+ case CombiningKind::MINNUMF:
+ return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
+ case CombiningKind::MAXNUMF:
+ return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
+ default:
+ return {};
+ }
+}
+
+template <>
+OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+ CombiningKind kind, ShapedType dstType) {
+ APInt srcVal = src.getValue();
+ APInt accVal = acc.getValue();
+ switch (kind) {
+ case CombiningKind::ADD:
+ return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
+ case CombiningKind::MUL: {
+ APInt result = accVal;
+ for (int i = 0; i < times; ++i) {
+ result *= srcVal;
+ }
+ return DenseElementsAttr::get(dstType, {result});
+ }
+ case CombiningKind::MINSI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.slt(srcVal) ? accVal : srcVal});
+ case CombiningKind::MAXSI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.ugt(srcVal) ? accVal : srcVal});
+ case CombiningKind::MINUI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.ult(srcVal) ? accVal : srcVal});
+ case CombiningKind::MAXUI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.ugt(srcVal) ? accVal : srcVal});
+ case CombiningKind::AND:
+ return DenseElementsAttr::get(dstType, {accVal & srcVal});
+ case CombiningKind::OR:
+ return DenseElementsAttr::get(dstType, {accVal | srcVal});
+ case CombiningKind::XOR:
+ return DenseElementsAttr::get(dstType,
+ {times & 0x1 ? accVal ^ srcVal : accVal});
+ default:
+ return {};
+ }
+}
+
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
// Single parallel dim, this is a noop.
if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
return getSource();
+ auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
+ auto accAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getAcc());
+ if (!srcAttr || !accAttr)
+ return {};
+ if (!srcAttr.isSplat() || !accAttr.isSplat())
+ return {};
+ auto reductionDims = getReductionDims();
+ auto srcType = mlir::cast<ShapedType>(getSourceVectorType());
+ auto srcDims = srcType.getShape();
+ int64_t times = 1;
+ for (auto dim : reductionDims) {
+ times *= srcDims[dim];
+ }
+ CombiningKind kind = getKind();
+ auto dstType = mlir::cast<ShapedType>(getDestType());
+ auto eltype = dstType.getElementType();
+ if (mlir::dyn_cast_or_null<FloatType>(eltype)) {
+ return foldSplatReduce<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
+ accAttr.getSplatValue<FloatAttr>(), times,
+ kind, dstType);
+ }
+ if (mlir::dyn_cast_or_null<IntegerType>(eltype)) {
+ return foldSplatReduce<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
+ accAttr.getSplatValue<IntegerAttr>(),
+ times, kind, dstType);
+ }
return {};
}
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 66c91d6b2041bf..43c52b4b36ca53 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -11,3 +11,57 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
%2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
return %2 : vector<4x4xf16>
}
+
+// CHECK-LABEL: fold_multid_reduction_f32_add
+func.func @fold_multid_reduction_f32_add() -> vector<1xf32> {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf32>
+ %0 = arith.constant dense<1.000000e+00> : vector<1x128x128xf32>
+ // CHECK: %{{.*}} = arith.constant dense<1.638400e+04> : vector<1xf32>
+ %1 = vector.multi_reduction <add>, %0, %cst_0 [1, 2] : vector<1x128x128xf32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_f32_mul
+func.func @fold_multid_reduction_f32_mul() -> vector<1xf32> {
+ %cst_0 = arith.constant dense<1.000000e+00> : vector<1xf32>
+ %0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
+ // CHECK: %{{.*}} = arith.constant dense<1.600000e+01> : vector<1xf32>
+ %1 = vector.multi_reduction <mul>, %0, %cst_0 [1, 2] : vector<1x2x2xf32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i32_add
+func.func @fold_multid_reduction_i32_add() -> vector<1xi32> {
+ %cst_1 = arith.constant dense<1> : vector<1xi32>
+ %0 = arith.constant dense<1> : vector<1x128x128xi32>
+ // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi32>
+ %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi32> to vector<1xi32>
+ return %1 : vector<1xi32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i32_xor_odd
+func.func @fold_multid_reduction_i32_xor_odd() -> vector<1xi32> {
+ %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
+ %0 = arith.constant dense<0xA0A> : vector<1x3xi32>
+ // CHECK: %{{.*}} = arith.constant dense<2805> : vector<1xi32>
+ %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x3xi32> to vector<1xi32>
+ return %1 : vector<1xi32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i32_xor_even
+func.func @fold_multid_reduction_i32_xor_even() -> vector<1xi32> {
+ %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
+ %0 = arith.constant dense<0xA0A> : vector<1x4xi32>
+ // CHECK: %{{.*}} = arith.constant dense<255> : vector<1xi32>
+ %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x4xi32> to vector<1xi32>
+ return %1 : vector<1xi32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i64_add
+func.func @fold_multid_reduction_i64_add() -> vector<1xi64> {
+ %cst_1 = arith.constant dense<1> : vector<1xi64>
+ %0 = arith.constant dense<1> : vector<1x128x128xi64>
+ // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi64>
+ %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi64> to vector<1xi64>
+ return %1 : vector<1xi64>
+}
\ No newline at end of file
>From 21bd7beedbcf7ac3aea5985559d491fa780986cf Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 13:25:29 +0000
Subject: [PATCH 02/16] fix fmt
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 94 ++++++++++++------------
1 file changed, 47 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a23d952c5760c5..9ee21aa3ad92b3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -476,29 +476,29 @@ OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
APFloat srcVal = src.getValue();
APFloat accVal = acc.getValue();
switch (kind) {
- case CombiningKind::ADD:{
- APFloat n = APFloat(srcVal.getSemantics());
- n.convertFromAPInt(APInt(64, times, true), true,
- APFloat::rmNearestTiesToEven);
- return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+ case CombiningKind::ADD: {
+ APFloat n = APFloat(srcVal.getSemantics());
+ n.convertFromAPInt(APInt(64, times, true), true,
+ APFloat::rmNearestTiesToEven);
+ return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+ }
+ case CombiningKind::MUL: {
+ APFloat result = accVal;
+ for (int64_t i = 0; i < times; ++i) {
+ result = result * srcVal;
}
- case CombiningKind::MUL: {
- APFloat result = accVal;
- for (int i = 0; i < times; ++i) {
- result = result * srcVal;
- }
- return DenseElementsAttr::get(dstType, {result});
- }
- case CombiningKind::MINIMUMF:
- return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
- case CombiningKind::MAXIMUMF:
- return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
- case CombiningKind::MINNUMF:
- return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
- case CombiningKind::MAXNUMF:
- return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
- default:
- return {};
+ return DenseElementsAttr::get(dstType, {result});
+ }
+ case CombiningKind::MINIMUMF:
+ return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
+ case CombiningKind::MAXIMUMF:
+ return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
+ case CombiningKind::MINNUMF:
+ return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
+ case CombiningKind::MAXNUMF:
+ return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
+ default:
+ return {};
}
}
@@ -508,32 +508,32 @@ OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
APInt srcVal = src.getValue();
APInt accVal = acc.getValue();
switch (kind) {
- case CombiningKind::ADD:
- return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
- case CombiningKind::MUL: {
- APInt result = accVal;
- for (int i = 0; i < times; ++i) {
- result *= srcVal;
- }
- return DenseElementsAttr::get(dstType, {result});
+ case CombiningKind::ADD:
+ return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
+ case CombiningKind::MUL: {
+ APInt result = accVal;
+ for (int64_t i = 0; i < times; ++i) {
+ result *= srcVal;
}
- case CombiningKind::MINSI:
- return DenseElementsAttr::get(
- dstType, {accVal.slt(srcVal) ? accVal : srcVal});
- case CombiningKind::MAXSI:
- return DenseElementsAttr::get(
- dstType, {accVal.ugt(srcVal) ? accVal : srcVal});
- case CombiningKind::MINUI:
- return DenseElementsAttr::get(
- dstType, {accVal.ult(srcVal) ? accVal : srcVal});
- case CombiningKind::MAXUI:
- return DenseElementsAttr::get(
- dstType, {accVal.ugt(srcVal) ? accVal : srcVal});
- case CombiningKind::AND:
- return DenseElementsAttr::get(dstType, {accVal & srcVal});
- case CombiningKind::OR:
- return DenseElementsAttr::get(dstType, {accVal | srcVal});
- case CombiningKind::XOR:
+ return DenseElementsAttr::get(dstType, {result});
+ }
+ case CombiningKind::MINSI:
+ return DenseElementsAttr::get(dstType,
+ {accVal.slt(srcVal) ? accVal : srcVal});
+ case CombiningKind::MAXSI:
+ return DenseElementsAttr::get(dstType,
+ {accVal.sgt(srcVal) ? accVal : srcVal});
+ case CombiningKind::MINUI:
+ return DenseElementsAttr::get(dstType,
+ {accVal.ult(srcVal) ? accVal : srcVal});
+ case CombiningKind::MAXUI:
+ return DenseElementsAttr::get(dstType,
+ {accVal.ugt(srcVal) ? accVal : srcVal});
+ case CombiningKind::AND:
+ return DenseElementsAttr::get(dstType, {accVal & srcVal});
+ case CombiningKind::OR:
+ return DenseElementsAttr::get(dstType, {accVal | srcVal});
+ case CombiningKind::XOR:
return DenseElementsAttr::get(dstType,
{times & 0x1 ? accVal ^ srcVal : accVal});
default:
>From 4b2438c24ed53c245903b0f7ce242c043a1e2679 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 20:08:40 +0000
Subject: [PATCH 03/16] Apply comments.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 6 +--
mlir/test/Dialect/Vector/constant-fold.mlir | 50 ++++++++++-----------
2 files changed, 28 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9ee21aa3ad92b3..b095fac11cfbfc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -551,16 +551,16 @@ OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
return {};
if (!srcAttr.isSplat() || !accAttr.isSplat())
return {};
- auto reductionDims = getReductionDims();
+ ArrayRef<int64_t> reductionDims = getReductionDims();
auto srcType = mlir::cast<ShapedType>(getSourceVectorType());
- auto srcDims = srcType.getShape();
+ ArrayRef<int64_t> srcDims = srcType.getShape();
int64_t times = 1;
for (auto dim : reductionDims) {
times *= srcDims[dim];
}
CombiningKind kind = getKind();
auto dstType = mlir::cast<ShapedType>(getDestType());
- auto eltype = dstType.getElementType();
+ Type eltype = dstType.getElementType();
if (mlir::dyn_cast_or_null<FloatType>(eltype)) {
return foldSplatReduce<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
accAttr.getSplatValue<FloatAttr>(), times,
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 43c52b4b36ca53..a6aa06525f8ba3 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -12,56 +12,56 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
return %2 : vector<4x4xf16>
}
-// CHECK-LABEL: fold_multid_reduction_f32_add
-func.func @fold_multid_reduction_f32_add() -> vector<1xf32> {
- %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK-LABEL: fold_multi_reduction_f32_add
+func.func @fold_multi_reduction_f32_add() -> vector<1xf32> {
+ %acc = arith.constant dense<0.000000e+00> : vector<1xf32>
%0 = arith.constant dense<1.000000e+00> : vector<1x128x128xf32>
// CHECK: %{{.*}} = arith.constant dense<1.638400e+04> : vector<1xf32>
- %1 = vector.multi_reduction <add>, %0, %cst_0 [1, 2] : vector<1x128x128xf32> to vector<1xf32>
+ %1 = vector.multi_reduction <add>, %0, %acc [1, 2] : vector<1x128x128xf32> to vector<1xf32>
return %1 : vector<1xf32>
}
-// CHECK-LABEL: fold_multid_reduction_f32_mul
-func.func @fold_multid_reduction_f32_mul() -> vector<1xf32> {
- %cst_0 = arith.constant dense<1.000000e+00> : vector<1xf32>
+// CHECK-LABEL: fold_multi_reduction_f32_mul
+func.func @fold_multi_reduction_f32_mul() -> vector<1xf32> {
+ %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
%0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
// CHECK: %{{.*}} = arith.constant dense<1.600000e+01> : vector<1xf32>
- %1 = vector.multi_reduction <mul>, %0, %cst_0 [1, 2] : vector<1x2x2xf32> to vector<1xf32>
+ %1 = vector.multi_reduction <mul>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
return %1 : vector<1xf32>
}
-// CHECK-LABEL: fold_multid_reduction_i32_add
-func.func @fold_multid_reduction_i32_add() -> vector<1xi32> {
- %cst_1 = arith.constant dense<1> : vector<1xi32>
+// CHECK-LABEL: fold_multi_reduction_i32_add
+func.func @fold_multi_reduction_i32_add() -> vector<1xi32> {
+ %acc = arith.constant dense<1> : vector<1xi32>
%0 = arith.constant dense<1> : vector<1x128x128xi32>
// CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi32>
- %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi32> to vector<1xi32>
+ %1 = vector.multi_reduction <add>, %0, %acc [1, 2] : vector<1x128x128xi32> to vector<1xi32>
return %1 : vector<1xi32>
}
-// CHECK-LABEL: fold_multid_reduction_i32_xor_odd
-func.func @fold_multid_reduction_i32_xor_odd() -> vector<1xi32> {
- %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
+// CHECK-LABEL: fold_multi_reduction_i32_xor_odd_num_elements
+func.func @fold_multi_reduction_i32_xor_odd_num_elements() -> vector<1xi32> {
+ %acc = arith.constant dense<0xFF> : vector<1xi32>
%0 = arith.constant dense<0xA0A> : vector<1x3xi32>
// CHECK: %{{.*}} = arith.constant dense<2805> : vector<1xi32>
- %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x3xi32> to vector<1xi32>
+ %1 = vector.multi_reduction <xor>, %0, %acc [1] : vector<1x3xi32> to vector<1xi32>
return %1 : vector<1xi32>
}
-// CHECK-LABEL: fold_multid_reduction_i32_xor_even
-func.func @fold_multid_reduction_i32_xor_even() -> vector<1xi32> {
- %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
+// CHECK-LABEL: fold_multi_reduction_i32_xor_even_num_elements
+func.func @fold_multi_reduction_i32_xor_even_num_elements() -> vector<1xi32> {
+ %acc = arith.constant dense<0xFF> : vector<1xi32>
%0 = arith.constant dense<0xA0A> : vector<1x4xi32>
// CHECK: %{{.*}} = arith.constant dense<255> : vector<1xi32>
- %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x4xi32> to vector<1xi32>
+ %1 = vector.multi_reduction <xor>, %0, %acc [1] : vector<1x4xi32> to vector<1xi32>
return %1 : vector<1xi32>
}
-// CHECK-LABEL: fold_multid_reduction_i64_add
-func.func @fold_multid_reduction_i64_add() -> vector<1xi64> {
- %cst_1 = arith.constant dense<1> : vector<1xi64>
+// CHECK-LABEL: fold_multi_reduction_i64_add
+func.func @fold_multi_reduction_i64_add() -> vector<1xi64> {
+ %acc = arith.constant dense<1> : vector<1xi64>
%0 = arith.constant dense<1> : vector<1x128x128xi64>
// CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi64>
- %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi64> to vector<1xi64>
+ %1 = vector.multi_reduction <add>, %0, %acc [1, 2] : vector<1x128x128xi64> to vector<1xi64>
return %1 : vector<1xi64>
-}
\ No newline at end of file
+}
>From 655381a54a272df7e7c89d743c6507f627ce7a5c Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 20:49:16 +0000
Subject: [PATCH 04/16] apply comments
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
mlir/test/Dialect/Vector/constant-fold.mlir | 27 +++++++++++++++++++++
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b095fac11cfbfc..ef8f5aaa461a8a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -46,7 +46,6 @@
#include "llvm/ADT/bit.h"
#include <cassert>
-#include <cmath>
#include <cstdint>
#include <numeric>
@@ -466,6 +465,8 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
+/// Helper function to reduce a multi reduction where src and acc are splat
+/// Folds src @^times acc into OpFoldResult where @ is the reduction operation (add/max/etc.)
template <typename T>
OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
ShapedType dstType);
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index a6aa06525f8ba3..0b1edbdc32c21c 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -30,6 +30,33 @@ func.func @fold_multi_reduction_f32_mul() -> vector<1xf32> {
return %1 : vector<1xf32>
}
+// CHECK-LABEL: fold_multi_reduction_f32_maximumf
+func.func @fold_multi_reduction_f32_maximumf() -> vector<1xf32> {
+ %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
+ %0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
+ // CHECK: %{{.*}} = arith.constant dense<2.600000e+01> : vector<1xf32>
+ %1 = vector.multi_reduction <maximumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: fold_multi_reduction_f32_minnumf
+func.func @fold_multi_reduction_f32_minnumf() -> vector<1xf32> {
+ %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
+ %0 = arith.constant dense<0xFFFFFFFF> : vector<1x2x2xf32>
+ // CHECK: %{{.*}} = arith.constant dense<1.000000e+01> : vector<1xf32>
+ %1 = vector.multi_reduction <minnumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: fold_multi_reduction_f32_minimumf
+func.func @fold_multi_reduction_f32_minimumf() -> vector<1xf32> {
+ %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
+ %0 = arith.constant dense<0xFFFFFFFF> : vector<1x2x2xf32>
+ // CHECK: %{{.*}} = arith.constant dense<0xFFFFFFFF> : vector<1xf32>
+ %1 = vector.multi_reduction <minimumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
// CHECK-LABEL: fold_multi_reduction_i32_add
func.func @fold_multi_reduction_i32_add() -> vector<1xi32> {
%acc = arith.constant dense<1> : vector<1xi32>
>From fdbab76dc5ff262e04d3a462668eae5941a86eeb Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 20:55:46 +0000
Subject: [PATCH 05/16] fix fmt.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ef8f5aaa461a8a..a9c0857881bb80 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -466,7 +466,8 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
}
/// Helper function to reduce a multi reduction where src and acc are splat
-/// Folds src @^times acc into OpFoldResult where @ is the reduction operation (add/max/etc.)
+/// Folds src @^times acc into OpFoldResult where @ is the reduction operation
+/// (add/max/etc.)
template <typename T>
OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
ShapedType dstType);
>From 80b93dc5021593d1617924f0fc96b59f57c625d1 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 21:05:22 +0000
Subject: [PATCH 06/16] Correct expected result in maximumf fold test
---
mlir/test/Dialect/Vector/constant-fold.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 0b1edbdc32c21c..f1a89fc9a80413 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -34,7 +34,7 @@ func.func @fold_multi_reduction_f32_mul() -> vector<1xf32> {
func.func @fold_multi_reduction_f32_maximumf() -> vector<1xf32> {
%acc = arith.constant dense<1.000000e+00> : vector<1xf32>
%0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
- // CHECK: %{{.*}} = arith.constant dense<2.600000e+01> : vector<1xf32>
+ // CHECK: %{{.*}} = arith.constant dense<2.000000e+01> : vector<1xf32>
%1 = vector.multi_reduction <maximumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
return %1 : vector<1xf32>
}
>From 91506445ca4eca3dff21a7bd8800cfedf3009914 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Fri, 10 Jan 2025 21:20:15 +0000
Subject: [PATCH 07/16] Fix expected results in maximumf/minnumf fold tests
---
mlir/test/Dialect/Vector/constant-fold.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index f1a89fc9a80413..50b81fa8fa9fe4 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -34,7 +34,7 @@ func.func @fold_multi_reduction_f32_mul() -> vector<1xf32> {
func.func @fold_multi_reduction_f32_maximumf() -> vector<1xf32> {
%acc = arith.constant dense<1.000000e+00> : vector<1xf32>
%0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
- // CHECK: %{{.*}} = arith.constant dense<2.000000e+01> : vector<1xf32>
+ // CHECK: %{{.*}} = arith.constant dense<2.000000e+00> : vector<1xf32>
%1 = vector.multi_reduction <maximumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
return %1 : vector<1xf32>
}
@@ -43,7 +43,7 @@ func.func @fold_multi_reduction_f32_maximumf() -> vector<1xf32> {
func.func @fold_multi_reduction_f32_minnumf() -> vector<1xf32> {
%acc = arith.constant dense<1.000000e+00> : vector<1xf32>
%0 = arith.constant dense<0xFFFFFFFF> : vector<1x2x2xf32>
- // CHECK: %{{.*}} = arith.constant dense<1.000000e+01> : vector<1xf32>
+ // CHECK: %{{.*}} = arith.constant dense<1.000000e+00> : vector<1xf32>
%1 = vector.multi_reduction <minnumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
return %1 : vector<1xf32>
}
>From 5b5981c0971bacbac1781a911d437a9ff0162967 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Sun, 12 Jan 2025 16:43:19 +0000
Subject: [PATCH 08/16] Add empty lines. Change comment on aux function + make
it static. Combine check. Rename eltype.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 26 ++++++++++++++----------
1 file changed, 15 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a9c0857881bb80..9944922454d8de 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -465,12 +465,11 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
-/// Helper function to reduce a multi reduction where src and acc are splat
-/// Folds src @^times acc into OpFoldResult where @ is the reduction operation
-/// (add/max/etc.)
+/// Computes the result of reducing a constant vector where the accumulator
+/// value, `acc`, is also constant.
template <typename T>
-OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
- ShapedType dstType);
+static OpFoldResult foldSplatReduce(T src, T acc, int64_t times,
+ CombiningKind kind, ShapedType dstType);
template <>
OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
@@ -509,6 +508,7 @@ OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
CombiningKind kind, ShapedType dstType) {
APInt srcVal = src.getValue();
APInt accVal = acc.getValue();
+
switch (kind) {
case CombiningKind::ADD:
return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
@@ -547,32 +547,36 @@ OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
// Single parallel dim, this is a noop.
if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
return getSource();
+
auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
auto accAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getAcc());
- if (!srcAttr || !accAttr)
- return {};
- if (!srcAttr.isSplat() || !accAttr.isSplat())
+ if (!srcAttr || !accAttr || !srcAttr.isSplat() || !accAttr.isSplat())
return {};
+
ArrayRef<int64_t> reductionDims = getReductionDims();
auto srcType = mlir::cast<ShapedType>(getSourceVectorType());
ArrayRef<int64_t> srcDims = srcType.getShape();
+
int64_t times = 1;
for (auto dim : reductionDims) {
times *= srcDims[dim];
}
+
CombiningKind kind = getKind();
auto dstType = mlir::cast<ShapedType>(getDestType());
- Type eltype = dstType.getElementType();
- if (mlir::dyn_cast_or_null<FloatType>(eltype)) {
+ Type dstEltType = dstType.getElementType();
+
+ if (mlir::dyn_cast_or_null<FloatType>(dstEltType)) {
return foldSplatReduce<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
accAttr.getSplatValue<FloatAttr>(), times,
kind, dstType);
}
- if (mlir::dyn_cast_or_null<IntegerType>(eltype)) {
+ if (mlir::dyn_cast_or_null<IntegerType>(dstEltType)) {
return foldSplatReduce<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
accAttr.getSplatValue<IntegerAttr>(),
times, kind, dstType);
}
+
return {};
}
>From 0528985805b37345f1c5df4d2d18c2370bdb9818 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Sun, 12 Jan 2025 16:47:45 +0000
Subject: [PATCH 09/16] remove redundant includes.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ---
1 file changed, 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9944922454d8de..36ec16e9636385 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -21,13 +21,10 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
>From 1b3920aef66644ccc79bf0308433ae4384308f57 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Sun, 12 Jan 2025 16:49:10 +0000
Subject: [PATCH 10/16] foldSplatReduce->computeConstantReduction
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 36ec16e9636385..29d9967768a207 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -465,11 +465,11 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
/// Computes the result of reducing a constant vector where the accumulator
/// value, `acc`, is also constant.
template <typename T>
-static OpFoldResult foldSplatReduce(T src, T acc, int64_t times,
+static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
CombiningKind kind, ShapedType dstType);
template <>
-OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc, int64_t times,
CombiningKind kind, ShapedType dstType) {
APFloat srcVal = src.getValue();
APFloat accVal = acc.getValue();
@@ -501,7 +501,7 @@ OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
}
template <>
-OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc, int64_t times,
CombiningKind kind, ShapedType dstType) {
APInt srcVal = src.getValue();
APInt accVal = acc.getValue();
@@ -564,12 +564,12 @@ OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
Type dstEltType = dstType.getElementType();
if (mlir::dyn_cast_or_null<FloatType>(dstEltType)) {
- return foldSplatReduce<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
+ return computeConstantReduction<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
accAttr.getSplatValue<FloatAttr>(), times,
kind, dstType);
}
if (mlir::dyn_cast_or_null<IntegerType>(dstEltType)) {
- return foldSplatReduce<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
+ return computeConstantReduction<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
accAttr.getSplatValue<IntegerAttr>(),
times, kind, dstType);
}
>From 33477c6ece8419c988d383ad082f5178b930638a Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Sun, 12 Jan 2025 16:53:11 +0000
Subject: [PATCH 11/16] Apply clang-format to computeConstantReduction
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 +++++++++++++-----------
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 29d9967768a207..d87be7eef8471a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -466,11 +466,13 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
/// value, `acc`, is also constant.
template <typename T>
static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
- CombiningKind kind, ShapedType dstType);
+ CombiningKind kind,
+ ShapedType dstType);
template <>
-OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc, int64_t times,
- CombiningKind kind, ShapedType dstType) {
+OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
+ int64_t times, CombiningKind kind,
+ ShapedType dstType) {
APFloat srcVal = src.getValue();
APFloat accVal = acc.getValue();
switch (kind) {
@@ -501,8 +503,9 @@ OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc, int64_t time
}
template <>
-OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc, int64_t times,
- CombiningKind kind, ShapedType dstType) {
+OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc,
+ int64_t times, CombiningKind kind,
+ ShapedType dstType) {
APInt srcVal = src.getValue();
APInt accVal = acc.getValue();
@@ -564,14 +567,14 @@ OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
Type dstEltType = dstType.getElementType();
if (mlir::dyn_cast_or_null<FloatType>(dstEltType)) {
- return computeConstantReduction<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
- accAttr.getSplatValue<FloatAttr>(), times,
- kind, dstType);
+ return computeConstantReduction<FloatAttr>(
+ srcAttr.getSplatValue<FloatAttr>(), accAttr.getSplatValue<FloatAttr>(),
+ times, kind, dstType);
}
if (mlir::dyn_cast_or_null<IntegerType>(dstEltType)) {
- return computeConstantReduction<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
- accAttr.getSplatValue<IntegerAttr>(),
- times, kind, dstType);
+ return computeConstantReduction<IntegerAttr>(
+ srcAttr.getSplatValue<IntegerAttr>(),
+ accAttr.getSplatValue<IntegerAttr>(), times, kind, dstType);
}
return {};
>From aaf75bbb440ce14c503d736bc121e958dcfddeb6 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Mon, 13 Jan 2025 12:24:33 +0000
Subject: [PATCH 12/16] Add a fast power helper for MUL reductions
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 54 +++++++++++++++++++-----
1 file changed, 44 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d87be7eef8471a..92deb7d98ee3d3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -468,6 +468,48 @@ template <typename T>
static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
CombiningKind kind,
ShapedType dstType);
+template <typename T>
+static T power(const T &a, int64_t times);
+
+template <>
+APFloat power(const APFloat &a, int64_t exponent) {
+ assert(exponent >= 0 && "negative exponents not supported.");
+ if (exponent == 0) {
+ return APFloat::getOne(a.getSemantics());
+ }
+ APFloat acc = a;
+ int64_t remainingExponent = exponent;
+ while (remainingExponent > 1) {
+ if (remainingExponent % 2 == 0) {
+ acc = acc * acc;
+ remainingExponent /= 2;
+ } else {
+ acc = acc * a;
+ remainingExponent--;
+ }
+ }
+ return acc;
+};
+
+template <>
+APInt power(const APInt &a, int64_t exponent) {
+ assert(exponent >= 0 && "negative exponents not supported.");
+ if (exponent == 0) {
+ return APInt(a.getBitWidth(), 1);
+ }
+ APInt acc = a;
+ int64_t remainingExponent = exponent;
+ while (remainingExponent > 1) {
+ if (remainingExponent % 2 == 0) {
+ acc = acc * acc;
+ remainingExponent /= 2;
+ } else {
+ acc = acc * a;
+ remainingExponent--;
+ }
+ }
+ return acc;
+};
template <>
OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
@@ -483,11 +525,7 @@ OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
}
case CombiningKind::MUL: {
- APFloat result = accVal;
- for (int i = 0; i < times; ++i) {
- result = result * srcVal;
- }
- return DenseElementsAttr::get(dstType, {result});
+ return DenseElementsAttr::get(dstType, {accVal * power(srcVal, times)});
}
case CombiningKind::MINIMUMF:
return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
@@ -513,11 +551,7 @@ OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc,
case CombiningKind::ADD:
return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
case CombiningKind::MUL: {
- APInt result = accVal;
- for (int i = 0; i < times; ++i) {
- result *= srcVal;
- }
- return DenseElementsAttr::get(dstType, {result});
+ return DenseElementsAttr::get(dstType, {accVal * power(srcVal, times)});
}
case CombiningKind::MINSI:
return DenseElementsAttr::get(dstType,
>From 580b20c607d2bf9a50c245395f9423150b16dd45 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Mon, 13 Jan 2025 19:18:22 +0000
Subject: [PATCH 13/16] Rename power to computePowerOf and add a TODO to move it to APFloat/APInt
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 92deb7d98ee3d3..25d3b4e4979318 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -468,11 +468,12 @@ template <typename T>
static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
CombiningKind kind,
ShapedType dstType);
+// TODO: move to APFloat, APInt headers.
template <typename T>
-static T power(const T &a, int64_t times);
+static T computePowerOf(const T &a, int64_t exponent);
template <>
-APFloat power(const APFloat &a, int64_t exponent) {
+APFloat computePowerOf(const APFloat &a, int64_t exponent) {
assert(exponent >= 0 && "negative exponents not supported.");
if (exponent == 0) {
return APFloat::getOne(a.getSemantics());
@@ -492,7 +493,7 @@ APFloat power(const APFloat &a, int64_t exponent) {
};
template <>
-APInt power(const APInt &a, int64_t exponent) {
+APInt computePowerOf(const APInt &a, int64_t exponent) {
assert(exponent >= 0 && "negative exponents not supported.");
if (exponent == 0) {
return APInt(a.getBitWidth(), 1);
@@ -525,7 +526,8 @@ OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
}
case CombiningKind::MUL: {
- return DenseElementsAttr::get(dstType, {accVal * power(srcVal, times)});
+ return DenseElementsAttr::get(dstType,
+ {accVal * computePowerOf(srcVal, times)});
}
case CombiningKind::MINIMUMF:
return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
>From b68c5f6149cc7bc9ef65af534b3f8bd4426cfb4a Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Mon, 13 Jan 2025 20:52:47 +0000
Subject: [PATCH 14/16] Document the `times` parameter of computeConstantReduction
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 25d3b4e4979318..5f81ffbde4453f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -463,7 +463,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
}
/// Computes the result of reducing a constant vector where the accumulator
-/// value, `acc`, is also constant.
+/// value, `acc`, is also constant. `times` is the number of times the operation is applied.
template <typename T>
static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
CombiningKind kind,
>From 71520f317bba810b65a405b3a7852701280372e7 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Mon, 13 Jan 2025 20:56:43 +0000
Subject: [PATCH 15/16] Rewrap the `times` doc comment to 80 columns
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5f81ffbde4453f..f5465bfe0241c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -463,7 +463,8 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
}
/// Computes the result of reducing a constant vector where the accumulator
-/// value, `acc`, is also constant. `times` is the number of times the operation is applied.
+/// value, `acc`, is also constant. `times` is the number of times the operation
+/// is applied.
template <typename T>
static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
CombiningKind kind,
>From 27801103b49e66984b037b11bb57131044fdc588 Mon Sep 17 00:00:00 2001
From: ImanHosseini <imanhosseini.17 at gmail.com>
Date: Mon, 13 Jan 2025 21:56:16 +0000
Subject: [PATCH 16/16] Use pre-decrement, replace template specializations with overloads, and finish the computePowerOf rename
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 52 ++++++++++--------------
1 file changed, 21 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f5465bfe0241c1..a6a08365dee467 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -462,19 +462,11 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
+/// TODO: Move to APFloat/APInt.
/// Computes the result of reducing a constant vector where the accumulator
/// value, `acc`, is also constant. `times` is the number of times the operation
/// is applied.
-template <typename T>
-static OpFoldResult computeConstantReduction(T src, T acc, int64_t times,
- CombiningKind kind,
- ShapedType dstType);
-// TODO: move to APFloat, APInt headers.
-template <typename T>
-static T computePowerOf(const T &a, int64_t exponent);
-
-template <>
-APFloat computePowerOf(const APFloat &a, int64_t exponent) {
+static APFloat computePowerOf(const APFloat &a, int64_t exponent) {
assert(exponent >= 0 && "negative exponents not supported.");
if (exponent == 0) {
return APFloat::getOne(a.getSemantics());
@@ -487,14 +479,13 @@ APFloat computePowerOf(const APFloat &a, int64_t exponent) {
remainingExponent /= 2;
} else {
acc = acc * a;
- remainingExponent--;
+ --remainingExponent;
}
}
return acc;
};
-template <>
-APInt computePowerOf(const APInt &a, int64_t exponent) {
+static APInt computePowerOf(const APInt &a, int64_t exponent) {
assert(exponent >= 0 && "negative exponents not supported.");
if (exponent == 0) {
return APInt(a.getBitWidth(), 1);
@@ -513,10 +504,9 @@ APInt computePowerOf(const APInt &a, int64_t exponent) {
return acc;
};
-template <>
-OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
- int64_t times, CombiningKind kind,
- ShapedType dstType) {
+static OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
+ int64_t times, CombiningKind kind,
+ ShapedType dstType) {
APFloat srcVal = src.getValue();
APFloat accVal = acc.getValue();
switch (kind) {
@@ -543,10 +533,9 @@ OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
}
}
-template <>
-OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc,
- int64_t times, CombiningKind kind,
- ShapedType dstType) {
+static OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc,
+ int64_t times, CombiningKind kind,
+ ShapedType dstType) {
APInt srcVal = src.getValue();
APInt accVal = acc.getValue();
@@ -554,7 +543,8 @@ OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc,
case CombiningKind::ADD:
return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
case CombiningKind::MUL: {
- return DenseElementsAttr::get(dstType, {accVal * power(srcVal, times)});
+ return DenseElementsAttr::get(dstType,
+ {accVal * computePowerOf(srcVal, times)});
}
case CombiningKind::MINSI:
return DenseElementsAttr::get(dstType,
@@ -591,27 +581,27 @@ OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
return {};
ArrayRef<int64_t> reductionDims = getReductionDims();
- auto srcType = mlir::cast<ShapedType>(getSourceVectorType());
+ auto srcType = cast<ShapedType>(getSourceVectorType());
ArrayRef<int64_t> srcDims = srcType.getShape();
int64_t times = 1;
- for (auto dim : reductionDims) {
+ for (int64_t dim : reductionDims) {
times *= srcDims[dim];
}
CombiningKind kind = getKind();
- auto dstType = mlir::cast<ShapedType>(getDestType());
+ auto dstType = cast<ShapedType>(getDestType());
Type dstEltType = dstType.getElementType();
if (mlir::dyn_cast_or_null<FloatType>(dstEltType)) {
- return computeConstantReduction<FloatAttr>(
- srcAttr.getSplatValue<FloatAttr>(), accAttr.getSplatValue<FloatAttr>(),
- times, kind, dstType);
+ return computeConstantReduction(srcAttr.getSplatValue<FloatAttr>(),
+ accAttr.getSplatValue<FloatAttr>(), times,
+ kind, dstType);
}
if (mlir::dyn_cast_or_null<IntegerType>(dstEltType)) {
- return computeConstantReduction<IntegerAttr>(
- srcAttr.getSplatValue<IntegerAttr>(),
- accAttr.getSplatValue<IntegerAttr>(), times, kind, dstType);
+ return computeConstantReduction(srcAttr.getSplatValue<IntegerAttr>(),
+ accAttr.getSplatValue<IntegerAttr>(), times,
+ kind, dstType);
}
return {};
More information about the Mlir-commits
mailing list