[Mlir-commits] [mlir] 0b8cb87 - [MLIR][STD] Add safe scalar constant propagation for FPTruncOp
Stella Stamenova
llvmlistbot at llvm.org
Fri Aug 6 16:32:32 PDT 2021
Author: Max Kudryavtsev
Date: 2021-08-06T16:31:29-07:00
New Revision: 0b8cb87e0d6b092a9c94f9cc0f16e56e954eddfd
URL: https://github.com/llvm/llvm-project/commit/0b8cb87e0d6b092a9c94f9cc0f16e56e954eddfd
DIFF: https://github.com/llvm/llvm-project/commit/0b8cb87e0d6b092a9c94f9cc0f16e56e954eddfd.diff
LOG: [MLIR][STD] Add safe scalar constant propagation for FPTruncOp
Perform scalar constant propagation for FPTruncOp only if the resulting value can be represented without precision loss or rounding.
Example:
%cst = constant 1.000000e+00 : f32
%0 = fptrunc %cst : f32 to bf16
-->
%cst = constant 1.000000e+00 : bf16
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D107518
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d03d3bd64eb6..7f54dc5da2f8 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1220,6 +1220,8 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc"> {
If the value cannot be exactly represented, it is rounded using the default
rounding mode. When operating on vectors, casts elementwise.
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 98165f2b62b5..c0d18193b260 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1414,6 +1414,27 @@ bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}
+/// Perform safe const propagation for fptrunc, i.e. only propagate
+/// if FP value can be represented without precision loss or rounding.
+OpFoldResult FPTruncOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  auto constOperand = operands.front().dyn_cast_or_null<FloatAttr>();
+  auto resultType = getType().dyn_cast<FloatType>();
+  if (!constOperand || !resultType)
+    return {};
+
+  // Truncate directly in APFloat semantics and propagate only if exact.
+  // Going via 'double' would assert for f80/f128 sources.
+  APFloat value = constOperand.getValue();
+  bool losesInfo = false;
+  APFloat::opStatus status = value.convert(
+      resultType.getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
+  if (losesInfo || status != APFloat::opOK)
+    return {};
+  return FloatAttr::get(resultType, value);
+}
+
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index fa5125c5b0c1..2d9f0932cde9 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -196,11 +196,10 @@ func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x
}
// CHECK-LABEL: @generalize_soft_plus_2d_f32
-// CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
+// CHECK: %[[C1:.+]] = constant 1.000000e+00 : f32
// CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-NEXT: %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT: %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+// CHECK-NEXT: %[[SUM:.+]] = addf %[[C1]], %[[EXP]] : f32
// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
// CHECK-NEXT: linalg.yield %[[LOG]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 3a81cf60c709..478b0835385b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -80,6 +80,25 @@ func @truncConstant(%arg0: i8) -> i16 {
return %tr : i16
}
+// CHECK-LABEL: @truncFPConstant
+// CHECK: %[[cres:.+]] = constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
func @truncFPConstant() -> bf16 {
+  %one = constant 1.000000e+00 : f32
+  %narrowed = fptrunc %one : f32 to bf16
+  return %narrowed : bf16
+}
+
+// Test that values which would require rounding are NOT propagated.
+// CHECK-LABEL: @truncFPConstantRounding
+// CHECK: constant 1.444000e+25 : f32
+// CHECK: fptrunc
func @truncFPConstantRounding() -> bf16 {
+  %inexact = constant 1.444000e+25 : f32
+  %narrowed = fptrunc %inexact : f32 to bf16
+  return %narrowed : bf16
+}
+
// CHECK-LABEL: @tripleAddAdd
// CHECK: %[[cres:.+]] = constant 59 : index
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
More information about the Mlir-commits mailing list