[Mlir-commits] [mlir] 0b8cb87 - [MLIR][STD] Add safe scalar constant propagation for FPTruncOp

Stella Stamenova llvmlistbot at llvm.org
Fri Aug 6 16:32:32 PDT 2021


Author: Max Kudryavtsev
Date: 2021-08-06T16:31:29-07:00
New Revision: 0b8cb87e0d6b092a9c94f9cc0f16e56e954eddfd

URL: https://github.com/llvm/llvm-project/commit/0b8cb87e0d6b092a9c94f9cc0f16e56e954eddfd
DIFF: https://github.com/llvm/llvm-project/commit/0b8cb87e0d6b092a9c94f9cc0f16e56e954eddfd.diff

LOG: [MLIR][STD] Add safe scalar constant propagation for FPTruncOp

Perform scalar constant propagation for FPTruncOp only if the resulting value can be represented without precision loss or rounding.

Example:
%cst = constant 1.000000e+00 : f32
%0 = fptrunc %cst : f32 to bf16
-->
%cst = constant 1.000000e+00 : bf16

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D107518

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d03d3bd64eb6..7f54dc5da2f8 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1220,6 +1220,8 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc"> {
     If the value cannot be exactly represented, it is rounded using the default
     rounding mode. When operating on vectors, casts elementwise.
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 98165f2b62b5..c0d18193b260 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1414,6 +1414,27 @@ bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
+/// Perform safe const propagation for fptrunc, i.e. only propagate if the
+/// FP value converts to the result type exactly (no precision loss/rounding).
+OpFoldResult FPTruncOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  auto constOperand = operands.front().dyn_cast_or_null<FloatAttr>();
+  if (!constOperand)
+    return {};
+
+  // Convert in APFloat directly: a 'double' round-trip asserts or rounds
+  // for sources wider than f64 (f80/f128), hiding precision loss.
+  APFloat value = constOperand.getValue();
+  bool losesInfo = false;
+  auto status = value.convert(getType().cast<FloatType>().getFloatSemantics(),
+                              APFloat::rmNearestTiesToEven, &losesInfo);
+  // Propagate only if the conversion was exact.
+  if (status != APFloat::opOK || losesInfo)
+    return {};
+  return FloatAttr::get(getType(), value);
+}
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index fa5125c5b0c1..2d9f0932cde9 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -196,11 +196,10 @@ func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x
 }
 
 // CHECK-LABEL: @generalize_soft_plus_2d_f32
-//      CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
+//      CHECK: %[[C1:.+]] = constant 1.000000e+00 : f32
 //      CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
 // CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1]], %[[EXP]] : f32
 // CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
 // CHECK-NEXT:   linalg.yield %[[LOG]] : f32
 // CHECK-NEXT: -> tensor<16x32xf32>

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 3a81cf60c709..478b0835385b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -80,6 +80,25 @@ func @truncConstant(%arg0: i8) -> i16 {
   return %tr : i16
 }
 
+// CHECK-LABEL: @truncFPConstant
+//       CHECK:   %[[cres:.+]] = constant 1.000000e+00 : bf16
+//       CHECK:   return %[[cres]]
+func @truncFPConstant() -> bf16 {
+  %cst = constant 1.000000e+00 : f32
+  %0 = fptrunc %cst : f32 to bf16
+  return %0 : bf16
+}
+
+// Test that truncations that would round (1.444e+25 is not exact in bf16) are NOT folded
+// CHECK-LABEL: @truncFPConstantRounding
+//       CHECK:   constant 1.444000e+25 : f32
+//       CHECK:   fptrunc
+func @truncFPConstantRounding() -> bf16 {
+  %cst = constant 1.444000e+25 : f32
+  %0 = fptrunc %cst : f32 to bf16
+  return %0 : bf16
+}
+
 // CHECK-LABEL: @tripleAddAdd
 //       CHECK:   %[[cres:.+]] = constant 59 : index 
 //       CHECK:   %[[add:.+]] = addi %arg0, %[[cres]] : index 


        


More information about the Mlir-commits mailing list