[Mlir-commits] [mlir] 1773ddd - [MLIR][Math] Enable constant folding of ops
William S. Moses
llvmlistbot at llvm.org
Wed Jan 12 09:19:48 PST 2022
Author: William S. Moses
Date: 2022-01-12T12:19:29-05:00
New Revision: 1773dddadf5de5ad0477394ac2f308eff38c9978
URL: https://github.com/llvm/llvm-project/commit/1773dddadf5de5ad0477394ac2f308eff38c9978
DIFF: https://github.com/llvm/llvm-project/commit/1773dddadf5de5ad0477394ac2f308eff38c9978.diff
LOG: [MLIR][Math] Enable constant folding of ops
Enable constant folding of ops within the math dialect, and introduce constant folders for ceil and log2
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D117085
Added:
mlir/test/Dialect/Math/canonicalize.mlir
Modified:
mlir/include/mlir/Dialect/Math/IR/MathBase.td
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Dialect/Math/IR/CMakeLists.txt
mlir/lib/Dialect/Math/IR/MathOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
index b9869d157a61..8df9e2ce6d13 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
@@ -15,6 +15,7 @@ def Math_Dialect : Dialect {
The math dialect is intended to hold mathematical operations on integer and
floating type beyond simple arithmetics.
}];
+ let hasConstantMaterializer = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
#endif // MATH_BASE
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index bef60175e4b6..b0f5d97457ec 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -195,6 +195,7 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
%x = math.ceil %y : tensor<4x?xf8>
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -649,6 +650,7 @@ def Math_Log2Op : Math_FloatUnaryOp<"log2"> {
%y = math.log2 %x : f64
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt
index 68acba99bc2f..45bb9c310cbf 100644
--- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRMath
MLIRMathOpsIncGen
LINK_LIBS PUBLIC
+ MLIRArithmetic
MLIRDialect
MLIRIR
)
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index d152de0a54d4..16dd2cb17e03 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -6,7 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/Builders.h"
using namespace mlir;
using namespace mlir::math;
@@ -17,3 +19,58 @@ using namespace mlir::math;
#define GET_OP_CLASSES
#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// CeilOp folder
+//===----------------------------------------------------------------------===//
+
+/// Folds `math.ceil` applied to a scalar float constant into a new FloatAttr
+/// of the op's result type. Returns a null OpFoldResult (no fold) when the
+/// operand is not constant or is not a FloatAttr.
+OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
+  // A null attribute means the operand is not a constant; any attribute kind
+  // other than FloatAttr is also rejected here.
+  FloatAttr inputAttr = operands.front().dyn_cast_or_null<FloatAttr>();
+  if (!inputAttr)
+    return {};
+
+  // ceil(x) is x rounded to an integral value in the direction of +infinity.
+  APFloat rounded = inputAttr.getValue();
+  rounded.roundToIntegral(llvm::RoundingMode::TowardPositive);
+  return FloatAttr::get(getType(), rounded);
+}
+
+//===----------------------------------------------------------------------===//
+// Log2Op folder
+//===----------------------------------------------------------------------===//
+
+/// Folds `math.log2` applied to a scalar f32/f64 constant by evaluating the
+/// host libm `log2`/`log2f` on it. Returns a null OpFoldResult (no fold) when
+/// the operand is not a constant FloatAttr, when the input is negative (the
+/// result would be NaN), or when the float width is neither 32 nor 64.
+OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
+  auto constOperand = operands.front();
+  if (!constOperand)
+    return {};
+
+  auto attr = constOperand.dyn_cast<FloatAttr>();
+  if (!attr)
+    return {};
+
+  auto floatType = getType().cast<FloatType>();
+  APFloat value = attr.getValue();
+
+  // Inputs below zero are a domain error for log2 (result NaN), so they are
+  // not folded. -0.0 is deliberately let through: log2(-0.0) is -inf, exactly
+  // like log2(+0.0), which is folded.
+  if (value.isNegative() && !value.isZero())
+    return {};
+
+  if (floatType.getWidth() == 64)
+    return FloatAttr::get(getType(), log2(value.convertToDouble()));
+
+  // Read the f32 value as a float and call the single-precision log2f, so the
+  // computation stays in the op's own precision with no double round trip.
+  if (floatType.getWidth() == 32)
+    return FloatAttr::get(getType(), log2f(value.convertToFloat()));
+
+  // Other float widths are left unfolded.
+  return {};
+}
+
+/// Materializes a folded constant of the math dialect by creating an
+/// `arith.constant` op at `loc` that holds `value` with result type `type`.
+Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
+                                                  Attribute value, Type type,
+                                                  Location loc) {
+  auto constantOp = builder.create<arith::ConstantOp>(loc, value, type);
+  return constantOp.getOperation();
+}
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
new file mode 100644
index 000000000000..f62f0cf0cde5
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// A non-integral constant: ceil(0.3) rounds toward +infinity to 1.0, so the
+// math.ceil op folds away and only the resulting constant is returned.
+// CHECK-LABEL: @ceil_fold
+// CHECK: %[[cst:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @ceil_fold() -> f32 {
+ %c = arith.constant 0.3 : f32
+ %r = math.ceil %c : f32
+ return %r : f32
+}
+
+// An already-integral constant: ceil(2.0) is 2.0 itself, and the op is still
+// folded into that constant.
+// CHECK-LABEL: @ceil_fold2
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @ceil_fold2() -> f32 {
+ %c = arith.constant 2.0 : f32
+ %r = math.ceil %c : f32
+ return %r : f32
+}
+
+// f32 power of two: log2(4.0) folds to the exact constant 2.0.
+// CHECK-LABEL: @log2_fold
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @log2_fold() -> f32 {
+ %c = arith.constant 4.0 : f32
+ %r = math.log2 %c : f32
+ return %r : f32
+}
+
+// f32 zero: log2(0.0) folds to -inf, printed as the f32 bit pattern
+// 0xFF800000.
+// CHECK-LABEL: @log2_fold2
+// CHECK: %[[cst:.+]] = arith.constant 0xFF800000 : f32
+// CHECK: return %[[cst]]
+func @log2_fold2() -> f32 {
+ %c = arith.constant 0.0 : f32
+ %r = math.log2 %c : f32
+ return %r : f32
+}
+
+// f32 negative input: log2(-1.0) is a domain error, so the folder declines
+// and both the constant and the math.log2 op remain.
+// CHECK-LABEL: @log2_nofold2
+// CHECK: %[[cst:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[res:.+]] = math.log2 %[[cst]] : f32
+// CHECK: return %[[res]]
+func @log2_nofold2() -> f32 {
+ %c = arith.constant -1.0 : f32
+ %r = math.log2 %c : f32
+ return %r : f32
+}
+
+// f64 power of two: log2(4.0) folds to the exact constant 2.0.
+// CHECK-LABEL: @log2_fold_64
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f64
+// CHECK: return %[[cst]]
+func @log2_fold_64() -> f64 {
+ %c = arith.constant 4.0 : f64
+ %r = math.log2 %c : f64
+ return %r : f64
+}
+
+// f64 zero: log2(0.0) folds to -inf, printed as the f64 bit pattern
+// 0xFFF0000000000000.
+// CHECK-LABEL: @log2_fold2_64
+// CHECK: %[[cst:.+]] = arith.constant 0xFFF0000000000000 : f64
+// CHECK: return %[[cst]]
+func @log2_fold2_64() -> f64 {
+ %c = arith.constant 0.0 : f64
+ %r = math.log2 %c : f64
+ return %r : f64
+}
+
+// f64 negative input: log2(-1.0) is a domain error, so the folder declines
+// and both the constant and the math.log2 op remain.
+// CHECK-LABEL: @log2_nofold2_64
+// CHECK: %[[cst:.+]] = arith.constant -1.000000e+00 : f64
+// CHECK: %[[res:.+]] = math.log2 %[[cst]] : f64
+// CHECK: return %[[res]]
+func @log2_nofold2_64() -> f64 {
+ %c = arith.constant -1.0 : f64
+ %r = math.log2 %c : f64
+ return %r : f64
+}
More information about the Mlir-commits
mailing list