[Mlir-commits] [mlir] 8dae909 - [mlir][vector] Add constant folding for fp16 to fp32 bitcast

Lei Zhang llvmlistbot at llvm.org
Fri Feb 5 06:16:04 PST 2021


Author: Lei Zhang
Date: 2021-02-05T09:12:50-05:00
New Revision: 8dae90997af7989b4afeb7586adacea40d9da002

URL: https://github.com/llvm/llvm-project/commit/8dae90997af7989b4afeb7586adacea40d9da002
DIFF: https://github.com/llvm/llvm-project/commit/8dae90997af7989b4afeb7586adacea40d9da002.diff

LOG: [mlir][vector] Add constant folding for fp16 to fp32 bitcast

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D96041

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 54e5e008e56f..f20b713e8e77 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/bit.h"
 #include <numeric>
 
 using namespace mlir;
@@ -2804,6 +2805,30 @@ OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
     if (result().getType() == otherOp.source().getType())
       return otherOp.source();
 
+  Attribute sourceConstant = operands.front();
+  if (!sourceConstant)
+    return {};
+
+  Type srcElemType = getSourceVectorType().getElementType();
+  Type dstElemType = getResultVectorType().getElementType();
+
+  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
+    if (floatPack.isSplat()) {
+      auto splat = floatPack.getSplatValue<FloatAttr>();
+
+      // Casting fp16 into fp32.
+      if (srcElemType.isF16() && dstElemType.isF32()) {
+        uint32_t bits = static_cast<uint32_t>(
+            splat.getValue().bitcastToAPInt().getZExtValue());
+        // Duplicate the 16-bit pattern.
+        bits = (bits << 16) | (bits & 0xffff);
+        APInt intBits(32, bits);
+        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
+        return DenseElementsAttr::get(getResultVectorType(), floatBits);
+      }
+    }
+  }
+
   return {};
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0ff85da85bcb..9d810e17bcb5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -556,6 +556,20 @@ func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
 
+// CHECK-LABEL: func @bitcast_f16_to_f32
+//              bit pattern: 0x00000000
+//       CHECK: %[[CST0:.+]] = constant dense<0.000000e+00> : vector<4xf32>
+//              bit pattern: 0x40004000
+//       CHECK: %[[CST1:.+]] = constant dense<2.00390625> : vector<4xf32>
+//       CHECK: return %[[CST0]], %[[CST1]]
+func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) { // splat f16 -> f32 bitcasts are expected to fold to splat f32 constants
+  %cst0 = constant dense<0.0> : vector<8xf16> // bit pattern: 0x0000
+  %cst1 = constant dense<2.0> : vector<8xf16> // bit pattern: 0x4000
+  %cast0 = vector.bitcast %cst0: vector<8xf16> to vector<4xf32> // each f32 lane = 0x0000 duplicated = 0x00000000
+  %cast1 = vector.bitcast %cst1: vector<8xf16> to vector<4xf32> // each f32 lane = 0x4000 duplicated = 0x40004000 (2.00390625)
+  return %cast0, %cast1: vector<4xf32>, vector<4xf32>
+}
+
 // -----
 
 // CHECK-LABEL: broadcast_folding1


        


More information about the Mlir-commits mailing list