[Mlir-commits] [mlir] 9d0b90e - [mlir][Math] Add TruncOp.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 8 19:01:40 PDT 2022


Author: jacquesguan
Date: 2022-09-09T10:01:28+08:00
New Revision: 9d0b90e9332d8224155554459cb07723f9880a04

URL: https://github.com/llvm/llvm-project/commit/9d0b90e9332d8224155554459cb07723f9880a04
DIFF: https://github.com/llvm/llvm-project/commit/9d0b90e9332d8224155554459cb07723f9880a04.diff

LOG: [mlir][Math] Add TruncOp.

This patch adds TruncOp to the Math dialect. It returns the operand rounded to the nearest integer not larger in magnitude than the operand (i.e., rounded toward zero). This patch also adds the corresponding LLVM intrinsic op.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D133342

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/lib/Dialect/Math/IR/MathOps.cpp
    mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
    mlir/test/Dialect/Math/canonicalize.mlir
    mlir/test/Dialect/Math/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 219efd57de357..8324408f0ef73 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -63,6 +63,7 @@ def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0]> {
 def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
 def LLVM_RoundEvenOp : LLVM_UnaryIntrinsicOp<"roundeven">;
 def LLVM_RoundOp : LLVM_UnaryIntrinsicOp<"round">;
+def LLVM_FTruncOp : LLVM_UnaryIntrinsicOp<"trunc">; // FP round toward zero.
 def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
 def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
 def LLVM_PowIOp : LLVM_BinaryIntrinsicOp<"powi">;

diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 0e25d1b1554cf..2747c8e02faf8 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -804,6 +804,35 @@ def Math_RoundOp : Math_FloatUnaryOp<"round"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// TruncOp
+//===----------------------------------------------------------------------===//
+
+def Math_TruncOp : Math_FloatUnaryOp<"trunc"> {
+  let summary = "trunc of the specified value";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.trunc` ssa-use `:` type
+    ```
+
+    The `trunc` operation returns the operand rounded toward zero, i.e. to the
+    nearest integer not larger in magnitude than the operand, in floating-point
+    format. It takes one operand of floating point type (i.e., scalar, tensor
+    or vector) and produces one result of the same type. The operation always
+    rounds toward zero, regardless of the current rounding direction.
+
+    Example:
+
+    ```mlir
+    // Scalar trunc operation.
+    %a = math.trunc %b : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // FPowIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index a9ce30efdd084..b67a86f443b5c 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -47,6 +47,8 @@ using RoundOpLowering =
     VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+using TruncOpLowering =
+    VectorConvertToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
 
 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
 template <typename MathOp, typename LLVMOp>
@@ -297,7 +299,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
     RoundOpLowering,
     RsqrtOpLowering,
     SinOpLowering,
-    SqrtOpLowering
+    SqrtOpLowering,
+    TruncOpLowering
   >(converter);
   // clang-format on
 }

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 561fdfcab5869..0e90defda26fd 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -155,19 +155,19 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
 void mlir::populateMathToLibmConversionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit,
     llvm::Optional<PatternBenefit> log1pBenefit) {
-  patterns
-      .add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
-           VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
-           VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
-           VecOpToScalarOp<math::RoundEvenOp>, VecOpToScalarOp<math::RoundOp>,
-           VecOpToScalarOp<math::AtanOp>, VecOpToScalarOp<math::TanOp>>(
-          patterns.getContext(), benefit);
+  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+               VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
+               VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
+               VecOpToScalarOp<math::RoundEvenOp>,
+               VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
+               VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
+      patterns.getContext(), benefit); // Unroll vectors into scalar ops.
   patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
                PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
                PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
                PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
-               PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>>(
-      patterns.getContext(), benefit);
+               PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>,
+               PromoteOpToF32<math::TruncOp>>(patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
                                                  "atan", benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
@@ -194,6 +194,8 @@ void mlir::populateMathToLibmConversionPatterns(
                                                   "floorf", "floor", benefit);
   patterns.add<ScalarOpToLibmCall<math::CeilOp>>(patterns.getContext(), "ceilf",
                                                  "ceil", benefit);
+  patterns.add<ScalarOpToLibmCall<math::TruncOp>>(patterns.getContext(),
+                                                  "truncf", "trunc", benefit);
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 7b3be5b4e8462..b5300745f26ab 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -457,6 +457,25 @@ OpFoldResult math::RoundOp::fold(ArrayRef<Attribute> operands) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// TruncOp folder
+//===----------------------------------------------------------------------===//
+
+/// Folds `math.trunc` of a constant operand by rounding toward zero directly
+/// in the operand's own APFloat semantics. Unlike a fold through host libm
+/// `trunc`/`truncf`, this is not limited to 32- and 64-bit floats and also
+/// folds other float types (e.g. f16, bf16, f80, f128).
+OpFoldResult math::TruncOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldUnaryOpConditional<FloatAttr>(
+      operands, [](const APFloat &a) -> Optional<APFloat> {
+        APFloat result(a);
+        // The returned opStatus (e.g. opInexact) is deliberately ignored:
+        // discarding the fractional part is exactly what trunc means.
+        result.roundToIntegral(APFloat::rmTowardZero);
+        return result;
+      });
+}
+
 /// Materialize an integer or floating point constant.
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,

diff  --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index b87ddbacab539..415ba1d9f001c 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -200,3 +200,13 @@ func.func @roundeven(%arg0 : f32) {
   %0 = math.roundeven %arg0 : f32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @trunc(
+// CHECK-SAME: f32
+func.func @trunc(%arg0 : f32) {
+  // CHECK: "llvm.intr.trunc"(%arg0) : (f32) -> f32
+  %0 = math.trunc %arg0 : f32 // lowered to the llvm.intr.trunc intrinsic
+  func.return
+}

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 2f00eaf53d91c..d911f8b1b8fbe 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -16,6 +16,8 @@
 // CHECK-DAG: @roundf(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @roundeven(f64) -> f64 attributes {llvm.readnone}
 // CHECK-DAG: @roundevenf(f32) -> f32 attributes {llvm.readnone}
+// CHECK-DAG: @trunc(f64) -> f64 attributes {llvm.readnone}
+// CHECK-DAG: @truncf(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @cos(f64) -> f64 attributes {llvm.readnone}
 // CHECK-DAG: @cosf(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @sin(f64) -> f64 attributes {llvm.readnone}
@@ -227,6 +229,17 @@ func.func @roundeven_caller(%float: f32, %double: f64) -> (f32, f64) {
   return %float_result, %double_result : f32, f64
 }
 
+// CHECK-LABEL: func @trunc_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @truncf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.trunc %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @trunc(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.trunc %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
 
 // CHECK-LABEL: func @cos_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
@@ -300,6 +313,29 @@ func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
 
+// CHECK-LABEL:   func @trunc_vec_caller(
+// CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+func.func @trunc_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+  // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+  // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+  // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
+  // CHECK:           %[[OUT0_F32:.*]] = call @truncf(%[[IN0_F32]]) : (f32) -> f32
+  // CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+  // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
+  // CHECK:           %[[OUT1_F32:.*]] = call @truncf(%[[IN1_F32]]) : (f32) -> f32
+  // CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+  %float_result = math.trunc %float : vector<2xf32> // unrolled to per-element truncf calls
+  // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
+  // CHECK:           %[[OUT0_F64:.*]] = call @trunc(%[[IN0_F64]]) : (f64) -> f64
+  // CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+  // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
+  // CHECK:           %[[OUT1_F64:.*]] = call @trunc(%[[IN1_F64]]) : (f64) -> f64
+  // CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+  %double_result = math.trunc %double : vector<2xf64> // unrolled to per-element trunc calls
+  // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+  return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
 
 // CHECK-LABEL: func @tan_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32

diff  --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index b0132d6ce0101..3625b1799585c 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -429,3 +429,30 @@ func.func @floor_fold2() -> f32 {
   %r = math.floor %c : f32
   return %r : f32
 }
+
+// CHECK-LABEL: @trunc_fold
+// CHECK-NEXT: %[[cst:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT:   return %[[cst]]
+func.func @trunc_fold() -> f32 {
+  %c = arith.constant 1.1 : f32
+  %r = math.trunc %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @trunc_fold_vec
+// CHECK-NEXT: %[[cst:.+]] = arith.constant dense<[0.000000e+00, -0.000000e+00, 1.000000e+00, -1.000000e+00]> : vector<4xf32>
+// CHECK-NEXT:   return %[[cst]]
+func.func @trunc_fold_vec() -> (vector<4xf32>) {
+  %v = arith.constant dense<[0.5, -0.5, 1.5, -1.5]> : vector<4xf32>
+  %0 = math.trunc %v : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @trunc_fold_f64
+// CHECK-NEXT: %[[cst:.+]] = arith.constant -2.000000e+00 : f64
+// CHECK-NEXT:   return %[[cst]]
+func.func @trunc_fold_f64() -> f64 {
+  %c = arith.constant -2.9 : f64
+  %r = math.trunc %c : f64
+  return %r : f64
+}

diff  --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 1af096c429409..d984cbb66f8c2 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -257,3 +257,15 @@ func.func @ipowi(%i: i32, %v: vector<4xi32>, %t: tensor<4x4x?xi32>) {
   %2 = math.ipowi %t, %t : tensor<4x4x?xi32>
   return
 }
+
+// CHECK-LABEL: func @trunc(
+// CHECK-SAME:             %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @trunc(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.trunc %[[F]] : f32
+  %0 = math.trunc %f : f32 // scalar operand
+  // CHECK: %{{.*}} = math.trunc %[[V]] : vector<4xf32>
+  %1 = math.trunc %v : vector<4xf32> // vector operand
+  // CHECK: %{{.*}} = math.trunc %[[T]] : tensor<4x4x?xf32>
+  %2 = math.trunc %t : tensor<4x4x?xf32> // tensor with a dynamic dimension
+  return
+}


        


More information about the Mlir-commits mailing list