[Mlir-commits] [mlir] 8dae909 - [mlir][vector] Add constant folding for fp16 to fp32 bitcast
Lei Zhang
llvmlistbot at llvm.org
Fri Feb 5 06:16:04 PST 2021
Author: Lei Zhang
Date: 2021-02-05T09:12:50-05:00
New Revision: 8dae90997af7989b4afeb7586adacea40d9da002
URL: https://github.com/llvm/llvm-project/commit/8dae90997af7989b4afeb7586adacea40d9da002
DIFF: https://github.com/llvm/llvm-project/commit/8dae90997af7989b4afeb7586adacea40d9da002.diff
LOG: [mlir][vector] Add constant folding for fp16 to fp32 bitcast
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D96041
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 54e5e008e56f..f20b713e8e77 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -25,6 +25,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/bit.h"
#include <numeric>
using namespace mlir;
@@ -2804,6 +2805,30 @@ OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
if (result().getType() == otherOp.source().getType())
return otherOp.source();
+ Attribute sourceConstant = operands.front();
+ if (!sourceConstant)
+ return {};
+
+ Type srcElemType = getSourceVectorType().getElementType();
+ Type dstElemType = getResultVectorType().getElementType();
+
+ if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
+ if (floatPack.isSplat()) {
+ auto splat = floatPack.getSplatValue<FloatAttr>();
+
+ // Casting fp16 into fp32.
+ if (srcElemType.isF16() && dstElemType.isF32()) {
+ uint32_t bits = static_cast<uint32_t>(
+ splat.getValue().bitcastToAPInt().getZExtValue());
+ // Duplicate the 16-bit pattern.
+ bits = (bits << 16) | (bits & 0xffff);
+ APInt intBits(32, bits);
+ APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
+ return DenseElementsAttr::get(getResultVectorType(), floatBits);
+ }
+ }
+ }
+
return {};
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0ff85da85bcb..9d810e17bcb5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -556,6 +556,20 @@ func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf
return %0, %2 : vector<4x8xf32>, vector<2xi32>
}
+// Verifies constant folding of vector.bitcast from a splat vector<8xf16>
+// constant to vector<4xf32>. Each f32 result lane is formed from two
+// adjacent (identical) f16 lanes, so the expected f32 bit pattern is the
+// f16 bit pattern repeated in both 16-bit halves, and the result is again
+// a splat constant.
+// CHECK-LABEL: func @bitcast_f16_to_f32
+// Expected f32 bit pattern: 0x00000000 (f16 0x0000 repeated twice).
+// CHECK: %[[CST0:.+]] = constant dense<0.000000e+00> : vector<4xf32>
+// Expected f32 bit pattern: 0x40004000 (f16 0x4000 repeated twice), which
+// decodes as the IEEE single-precision value 2.00390625.
+// CHECK: %[[CST1:.+]] = constant dense<2.00390625> : vector<4xf32>
+// CHECK: return %[[CST0]], %[[CST1]]
+func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
+ %cst0 = constant dense<0.0> : vector<8xf16> // bit pattern: 0x0000
+ %cst1 = constant dense<2.0> : vector<8xf16> // bit pattern: 0x4000
+ %cast0 = vector.bitcast %cst0: vector<8xf16> to vector<4xf32>
+ %cast1 = vector.bitcast %cst1: vector<8xf16> to vector<4xf32>
+ return %cast0, %cast1: vector<4xf32>, vector<4xf32>
+}
+
// -----
// CHECK-LABEL: broadcast_folding1
More information about the Mlir-commits mailing list