[Mlir-commits] [mlir] e609417 - [mlir][Math] Add more constant folder for Math ops.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 21 19:23:34 PDT 2022
Author: jacquesguan
Date: 2022-03-22T10:23:15+08:00
New Revision: e609417cdc934c6101ca512b00edcf47d9aa4211
URL: https://github.com/llvm/llvm-project/commit/e609417cdc934c6101ca512b00edcf47d9aa4211
DIFF: https://github.com/llvm/llvm-project/commit/e609417cdc934c6101ca512b00edcf47d9aa4211.diff
LOG: [mlir][Math] Add more constant folder for Math ops.
This revision adds constant folders for abs, copysign, ctlz, cttz and
ctpop.
Differential Revision: https://reviews.llvm.org/D122115
Added:
Modified:
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/test/Dialect/Math/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index b0ccbc21439b0..221af3f6a5f2c 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -89,6 +89,7 @@ def Math_AbsOp : Math_FloatUnaryOp<"abs"> {
%x = math.abs %y : tensor<4x?xf8>
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -230,6 +231,7 @@ def Math_CopySignOp : Math_FloatBinaryOp<"copysign"> {
%x = math.copysign %y, %z : tensor<4x?xf8>
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -320,6 +322,7 @@ def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
%x = math.ctlz %y : tensor<4x?xi8>
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -344,6 +347,7 @@ def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
%x = math.cttz %y : tensor<4x?xi8>
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -368,6 +372,7 @@ def Math_CtPopOp : Math_IntegerUnaryOp<"ctpop"> {
%x = math.ctpop %y : tensor<4x?xi8>
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 42f8334403c1e..28f42f814f6dc 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -20,6 +20,32 @@ using namespace mlir::math;
#define GET_OP_CLASSES
#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
+//===----------------------------------------------------------------------===//
+// AbsOp folder
+//===----------------------------------------------------------------------===//
+// Folds math.abs of a FloatAttr operand for 32/64-bit results; else returns {}.
+OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
+ auto constOperand = operands.front();
+ if (!constOperand)
+ return {};
+ // A null attribute was rejected above; now require a FloatAttr specifically.
+ auto attr = constOperand.dyn_cast<FloatAttr>();
+ if (!attr)
+ return {};
+ // NOTE(review): cast<> presumes the result type is a FloatType - confirm.
+ auto ft = getType().cast<FloatType>();
+
+ APFloat apf = attr.getValue();
+ // 64-bit result: apply fabs to the value converted to double.
+ if (ft.getWidth() == 64)
+ return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
+ // 32-bit result: apply fabsf to the value converted to float.
+ if (ft.getWidth() == 32)
+ return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
+ // Every other width is left unfolded.
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// CeilOp folder
//===----------------------------------------------------------------------===//
@@ -39,6 +65,81 @@ OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
return FloatAttr::get(getType(), sourceVal);
}
+//===----------------------------------------------------------------------===//
+// CopySignOp folder
+//===----------------------------------------------------------------------===//
+// Folds math.copysign when both operand attributes are FloatAttr; else {}.
+OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
+ auto ft = getType().dyn_cast<FloatType>();
+ if (!ft)
+ return {};
+ // Placeholder values built from the result type's float semantics.
+ APFloat vals[2]{APFloat(ft.getFloatSemantics()),
+ APFloat(ft.getFloatSemantics())};
+ for (int i = 0; i < 2; ++i) {
+ if (!operands[i])
+ return {};
+ // Give up unless this operand attribute is a FloatAttr.
+ auto attr = operands[i].dyn_cast<FloatAttr>();
+ if (!attr)
+ return {};
+ // Replace the placeholder with the operand's value.
+ vals[i] = attr.getValue();
+ }
+ // Updates vals[0] in place using vals[1] as the sign source.
+ vals[0].copySign(vals[1]);
+
+ return FloatAttr::get(getType(), vals[0]);
+}
+
+//===----------------------------------------------------------------------===//
+// CountLeadingZerosOp folder
+//===----------------------------------------------------------------------===//
+// Folds math.ctlz when its operand attribute is an IntegerAttr; else {}.
+OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
+ auto constOperand = operands.front();
+ if (!constOperand)
+ return {};
+ // Only an IntegerAttr operand is folded.
+ auto attr = constOperand.dyn_cast<IntegerAttr>();
+ if (!attr)
+ return {};
+ // The countLeadingZeros() result becomes an IntegerAttr of the result type.
+ return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
+}
+
+//===----------------------------------------------------------------------===//
+// CountTrailingZerosOp folder
+//===----------------------------------------------------------------------===//
+// Folds math.cttz when its operand attribute is an IntegerAttr; else {}.
+OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
+ auto constOperand = operands.front();
+ if (!constOperand)
+ return {};
+ // Only an IntegerAttr operand is folded.
+ auto attr = constOperand.dyn_cast<IntegerAttr>();
+ if (!attr)
+ return {};
+ // The countTrailingZeros() result becomes an IntegerAttr of the result type.
+ return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
+}
+
+//===----------------------------------------------------------------------===//
+// CtPopOp folder
+//===----------------------------------------------------------------------===//
+// Folds math.ctpop when its operand attribute is an IntegerAttr; else {}.
+OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
+ auto constOperand = operands.front();
+ if (!constOperand)
+ return {};
+ // Only an IntegerAttr operand is folded.
+ auto attr = constOperand.dyn_cast<IntegerAttr>();
+ if (!attr)
+ return {};
+ // The countPopulation() result becomes an IntegerAttr of the result type.
+ return IntegerAttr::get(getType(), attr.getValue().countPopulation());
+}
+
//===----------------------------------------------------------------------===//
// Log2Op folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 27a92908eaec4..45b13b455a2f0 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -91,3 +91,58 @@ func @sqrt_fold() -> f32 {
%r = math.sqrt %c : f32
return %r : f32
}
+
+// CHECK-LABEL: @abs_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @abs_fold() -> f32 {
+ %c = arith.constant -4.0 : f32
+ %r = math.abs %c : f32 // expected to fold to the 4.0 constant matched above
+ return %r : f32
+}
+
+// CHECK-LABEL: @copysign_fold
+// CHECK: %[[cst:.+]] = arith.constant -4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @copysign_fold() -> f32 {
+ %c1 = arith.constant 4.0 : f32
+ %c2 = arith.constant -9.0 : f32
+ %r = math.copysign %c1, %c2 : f32 // magnitude of %c1, sign of %c2: -4.0
+ return %r : f32
+}
+
+// CHECK-LABEL: @ctlz_fold1
+// CHECK: %[[cst:.+]] = arith.constant 31 : i32
+// CHECK: return %[[cst]]
+func @ctlz_fold1() -> i32 {
+ %c = arith.constant 1 : i32
+ %r = math.ctlz %c : i32 // 1 as i32 has 31 leading zero bits
+ return %r : i32
+}
+
+// CHECK-LABEL: @ctlz_fold2
+// CHECK: %[[cst:.+]] = arith.constant 7 : i8
+// CHECK: return %[[cst]]
+func @ctlz_fold2() -> i8 {
+ %c = arith.constant 1 : i8
+ %r = math.ctlz %c : i8 // same input at i8 width: 7 leading zero bits
+ return %r : i8
+}
+
+// CHECK-LABEL: @cttz_fold
+// CHECK: %[[cst:.+]] = arith.constant 8 : i32
+// CHECK: return %[[cst]]
+func @cttz_fold() -> i32 {
+ %c = arith.constant 256 : i32
+ %r = math.cttz %c : i32 // 256 == 1 << 8, so 8 trailing zero bits
+ return %r : i32
+}
+
+// CHECK-LABEL: @ctpop_fold
+// CHECK: %[[cst:.+]] = arith.constant 16 : i32
+// CHECK: return %[[cst]]
+func @ctpop_fold() -> i32 {
+ %c = arith.constant 0xFF0000FF : i32
+ %r = math.ctpop %c : i32 // 0xFF0000FF has 16 set bits
+ return %r : i32
+}
More information about the Mlir-commits
mailing list