[Mlir-commits] [mlir] e609417 - [mlir][Math] Add more constant folder for Math ops.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 21 19:23:34 PDT 2022


Author: jacquesguan
Date: 2022-03-22T10:23:15+08:00
New Revision: e609417cdc934c6101ca512b00edcf47d9aa4211

URL: https://github.com/llvm/llvm-project/commit/e609417cdc934c6101ca512b00edcf47d9aa4211
DIFF: https://github.com/llvm/llvm-project/commit/e609417cdc934c6101ca512b00edcf47d9aa4211.diff

LOG: [mlir][Math] Add more constant folder for Math ops.

This revision adds constant folders for abs, copysign, ctlz, cttz and
ctpop.

Differential Revision: https://reviews.llvm.org/D122115

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Dialect/Math/IR/MathOps.cpp
    mlir/test/Dialect/Math/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index b0ccbc21439b0..221af3f6a5f2c 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -89,6 +89,7 @@ def Math_AbsOp : Math_FloatUnaryOp<"abs"> {
     %x = math.abs %y : tensor<4x?xf8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -230,6 +231,7 @@ def Math_CopySignOp : Math_FloatBinaryOp<"copysign"> {
     %x = math.copysign %y, %z : tensor<4x?xf8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -320,6 +322,7 @@ def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
     %x = math.ctlz %y : tensor<4x?xi8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -344,6 +347,7 @@ def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
     %x = math.cttz %y : tensor<4x?xi8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -368,6 +372,7 @@ def Math_CtPopOp : Math_IntegerUnaryOp<"ctpop"> {
     %x = math.ctpop %y : tensor<4x?xi8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 42f8334403c1e..28f42f814f6dc 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -20,6 +20,32 @@ using namespace mlir::math;
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// AbsOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) { // Folds math.abs of a float constant.
+  auto constOperand = operands.front(); // Attribute for the op's first operand.
+  if (!constOperand) // No attribute (presumably a non-constant operand): no fold.
+    return {};
+
+  auto attr = constOperand.dyn_cast<FloatAttr>();
+  if (!attr) // Only a FloatAttr operand is handled here.
+    return {};
+
+  auto ft = getType().cast<FloatType>();
+
+  APFloat apf = attr.getValue();
+
+  if (ft.getWidth() == 64) // 64-bit result: evaluate via host double.
+    return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
+
+  if (ft.getWidth() == 32) // 32-bit result: evaluate via host float.
+    return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
+
+  return {}; // Any other float width is left unfolded.
+}
+
 //===----------------------------------------------------------------------===//
 // CeilOp folder
 //===----------------------------------------------------------------------===//
@@ -39,6 +65,81 @@ OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
   return FloatAttr::get(getType(), sourceVal);
 }
 
+//===----------------------------------------------------------------------===//
+// CopySignOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) { // Folds math.copysign of two float constants.
+  auto ft = getType().dyn_cast<FloatType>();
+  if (!ft) // Result type is not a FloatType: no fold.
+    return {};
+
+  APFloat vals[2]{APFloat(ft.getFloatSemantics()), // Placeholders; overwritten in the loop below.
+                  APFloat(ft.getFloatSemantics())};
+  for (int i = 0; i < 2; ++i) { // Both operands must be present and FloatAttr.
+    if (!operands[i])
+      return {};
+
+    auto attr = operands[i].dyn_cast<FloatAttr>();
+    if (!attr)
+      return {};
+
+    vals[i] = attr.getValue();
+  }
+
+  vals[0].copySign(vals[1]); // First operand's value takes the second operand's sign.
+
+  return FloatAttr::get(getType(), vals[0]);
+}
+
+//===----------------------------------------------------------------------===//
+// CountLeadingZerosOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) { // Folds math.ctlz of an integer constant.
+  auto constOperand = operands.front(); // Attribute for the op's first operand.
+  if (!constOperand) // No attribute (presumably a non-constant operand): no fold.
+    return {};
+
+  auto attr = constOperand.dyn_cast<IntegerAttr>();
+  if (!attr) // Only an IntegerAttr operand is handled here.
+    return {};
+
+  return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros()); // Count is emitted in the result type.
+}
+
+//===----------------------------------------------------------------------===//
+// CountTrailingZerosOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) { // Folds math.cttz of an integer constant.
+  auto constOperand = operands.front(); // Attribute for the op's first operand.
+  if (!constOperand) // No attribute (presumably a non-constant operand): no fold.
+    return {};
+
+  auto attr = constOperand.dyn_cast<IntegerAttr>();
+  if (!attr) // Only an IntegerAttr operand is handled here.
+    return {};
+
+  return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros()); // Count is emitted in the result type.
+}
+
+//===----------------------------------------------------------------------===//
+// CtPopOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) { // Folds math.ctpop of an integer constant.
+  auto constOperand = operands.front(); // Attribute for the op's first operand.
+  if (!constOperand) // No attribute (presumably a non-constant operand): no fold.
+    return {};
+
+  auto attr = constOperand.dyn_cast<IntegerAttr>();
+  if (!attr) // Only an IntegerAttr operand is handled here.
+    return {};
+
+  return IntegerAttr::get(getType(), attr.getValue().countPopulation()); // Set-bit count, emitted in the result type.
+}
+
 //===----------------------------------------------------------------------===//
 // Log2Op folder
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 27a92908eaec4..45b13b455a2f0 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -91,3 +91,58 @@ func @sqrt_fold() -> f32 {
   %r = math.sqrt %c : f32
   return %r : f32
 }
+
+// CHECK-LABEL: @abs_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @abs_fold() -> f32 { // math.abs on a constant f32 input.
+  %c = arith.constant -4.0 : f32 // Negative input.
+  %r = math.abs %c : f32 // The FileCheck directives above expect the constant 4.0 to be returned.
+  return %r : f32
+}
+
+// CHECK-LABEL: @copysign_fold
+// CHECK: %[[cst:.+]] = arith.constant -4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @copysign_fold() -> f32 { // math.copysign on two constant f32 inputs.
+  %c1 = arith.constant 4.0 : f32 // Supplies the magnitude.
+  %c2 = arith.constant -9.0 : f32 // Supplies the sign.
+  %r = math.copysign %c1, %c2 : f32 // The FileCheck directives above expect the constant -4.0 to be returned.
+  return %r : f32
+}
+
+// CHECK-LABEL: @ctlz_fold1
+// CHECK: %[[cst:.+]] = arith.constant 31 : i32
+// CHECK: return %[[cst]]
+func @ctlz_fold1() -> i32 { // math.ctlz on a constant i32 input.
+  %c = arith.constant 1 : i32 // Only the lowest bit is set.
+  %r = math.ctlz %c : i32 // The FileCheck directives above expect the constant 31 to be returned.
+  return %r : i32
+}
+
+// CHECK-LABEL: @ctlz_fold2
+// CHECK: %[[cst:.+]] = arith.constant 7 : i8
+// CHECK: return %[[cst]]
+func @ctlz_fold2() -> i8 { // math.ctlz on a constant i8 input.
+  %c = arith.constant 1 : i8 // Only the lowest bit is set.
+  %r = math.ctlz %c : i8 // The FileCheck directives above expect the constant 7 to be returned.
+  return %r : i8
+}
+
+// CHECK-LABEL: @cttz_fold
+// CHECK: %[[cst:.+]] = arith.constant 8 : i32
+// CHECK: return %[[cst]]
+func @cttz_fold() -> i32 { // math.cttz on a constant i32 input.
+  %c = arith.constant 256 : i32 // 256 is 2^8, so only bit 8 is set.
+  %r = math.cttz %c : i32 // The FileCheck directives above expect the constant 8 to be returned.
+  return %r : i32
+}
+
+// CHECK-LABEL: @ctpop_fold
+// CHECK: %[[cst:.+]] = arith.constant 16 : i32
+// CHECK: return %[[cst]]
+func @ctpop_fold() -> i32 { // math.ctpop on a constant i32 input.
+  %c = arith.constant 0xFF0000FF : i32 // Top byte and bottom byte set: 16 one-bits.
+  %r = math.ctpop %c : i32 // The FileCheck directives above expect the constant 16 to be returned.
+  return %r : i32
+}


        


More information about the Mlir-commits mailing list