[Mlir-commits] [mlir] ca6baf1 - [MLIR][std] Introduce bitcast operation

Geoffrey Martin-Noble llvmlistbot at llvm.org
Fri Aug 6 08:48:02 PDT 2021


Author: Geoffrey Martin-Noble
Date: 2021-08-06T08:47:51-07:00
New Revision: ca6baf1e1da2ef1dcfcae837242ef0024a75f400

URL: https://github.com/llvm/llvm-project/commit/ca6baf1e1da2ef1dcfcae837242ef0024a75f400
DIFF: https://github.com/llvm/llvm-project/commit/ca6baf1e1da2ef1dcfcae837242ef0024a75f400.diff

LOG: [MLIR][std] Introduce bitcast operation

This patch introduces a bitcast operation to the standard dialect.
RFC: https://llvm.discourse.group/t/rfc-introduce-a-bitcast-op/3774

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D105376

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 9aa7c5cfbb314..d03d3bd64eb6d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -456,6 +456,32 @@ def AtomicYieldOp : Std_Op<"atomic_yield", [
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+def BitcastOp : ArithmeticCastOp<"bitcast"> {
+  let summary = "bitcast between values of equal bit width";
+  let description = [{
+    Bitcast an integer or floating point value to an integer or floating point
+    value of equal bit width. When operating on vectors, casts elementwise.
+
+    Note that this implements a logical bitcast independent of target
+    endianness. This allows constant folding without target information and is
+    consistent with the bitcast constant folders in LLVM (see
+    https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168).
+    For targets where the source and target type have the same endianness (which
+    is the standard), this cast will also change no bits at runtime, but it may
+    still require an operation, for example if the machine has different
+    floating point and integer register files. For targets that have a different
+    endianness for the source and target types (e.g. float is big-endian and
+    integer is little-endian) a proper lowering would add operations to swap the
+    order of words in addition to the bitcast.
+  }];
+  let hasFolder = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index e9ec915cc8717..98165f2b62b5b 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -482,6 +483,62 @@ static LogicalResult verify(AtomicYieldOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  assert(inputs.size() == 1 && outputs.size() == 1 &&
+         "bitcast op expects one operand and result");
+  Type a = inputs.front(), b = outputs.front(); // source, result types
+  if (a.isSignlessIntOrFloat() && b.isSignlessIntOrFloat()) // scalars
+    return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
+}
+
+OpFoldResult BitcastOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "bitcastop expects 1 operand");
+
+  // bitcast(bitcast(x)): fold in place by rewiring the operand to x.
+  auto *sourceOp = getOperand().getDefiningOp();
+  if (auto sourceBitcast = dyn_cast_or_null<BitcastOp>(sourceOp)) {
+    setOperand(sourceBitcast.getOperand());
+    return getResult();
+  }
+
+  auto operand = operands[0]; // constant operand value, or null
+  if (!operand)
+    return {};
+
+  Type resType = getResult().getType();
+
+  if (auto denseAttr = operand.dyn_cast<DenseFPElementsAttr>()) {
+    Type elType = getElementTypeOrSelf(resType);
+    return denseAttr.mapValues(
+        elType, [](const APFloat &f) { return f.bitcastToAPInt(); });
+  }
+  if (auto denseAttr = operand.dyn_cast<DenseIntElementsAttr>()) {
+    Type elType = getElementTypeOrSelf(resType);
+    // mapValues does its own bitcast to the target type.
+    return denseAttr.mapValues(elType, [](const APInt &i) { return i; });
+  }
+
+  APInt bits; // raw bits of a scalar int/float constant
+  if (auto floatAttr = operand.dyn_cast<FloatAttr>())
+    bits = floatAttr.getValue().bitcastToAPInt();
+  else if (auto intAttr = operand.dyn_cast<IntegerAttr>())
+    bits = intAttr.getValue();
+  else
+    return {}; // not a foldable constant kind
+
+  if (resType.isa<IntegerType>()) // reinterpret bits as resType
+    return IntegerAttr::get(resType, bits);
+  if (auto resFloatType = resType.dyn_cast<FloatType>())
+    return FloatAttr::get(resType,
+                          APFloat(resFloatType.getFloatSemantics(), bits));
+  return {}; // result type is neither integer nor float
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index d2ef830537f9f..3a81cf60c7096 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -331,3 +331,102 @@ func @selToNot(%arg0: i1) -> i1 {
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// -----
+
+// CHECK-LABEL: @bitcastSameType(
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
+func @bitcastSameType(%arg : f32) -> f32 {
+  // CHECK: return %[[ARG]]
+  %res = bitcast %arg : f32 to f32  // same-type cast is expected to fold to %arg
+  return %res : f32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantFPtoI(
+func @bitcastConstantFPtoI() -> i32 {
+  // CHECK: %[[C0:.+]] = constant 0 : i32
+  // CHECK: return %[[C0]]
+  %c0 = constant 0.0 : f32
+  %res = bitcast %c0 : f32 to i32  // +0.0 has all-zero bits, so folds to i32 0
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantItoFP(
+func @bitcastConstantItoFP() -> f32 {
+  // CHECK: %[[C0:.+]] = constant 0.0{{.*}} : f32
+  // CHECK: return %[[C0]]
+  %c0 = constant 0 : i32
+  %res = bitcast %c0 : i32 to f32  // all-zero bits fold to f32 0.0
+  return %res : f32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantFPtoFP(
+func @bitcastConstantFPtoFP() -> f16 {
+  // CHECK: %[[C0:.+]] = constant 0.0{{.*}} : f16
+  // CHECK: return %[[C0]]
+  %c0 = constant 0.0 : bf16
+  %res = bitcast %c0 : bf16 to f16  // zero bits in bf16 are zero bits in f16
+  return %res : f16
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantVecItoFP(
+func @bitcastConstantVecItoFP() -> vector<3xf32> {
+  // CHECK: %[[C0:.+]] = constant dense<0.0{{.*}}> : vector<3xf32>
+  // CHECK: return %[[C0]]
+  %c0 = constant dense<0> : vector<3xi32>
+  %res = bitcast %c0 : vector<3xi32> to vector<3xf32>  // integer -> float elements
+  return %res : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantVecFPtoI(
+func @bitcastConstantVecFPtoI() -> vector<3xi32> {
+  // CHECK: %[[C0:.+]] = constant dense<0> : vector<3xi32>
+  // CHECK: return %[[C0]]
+  %c0 = constant dense<0.0> : vector<3xf32>
+  %res = bitcast %c0 : vector<3xf32> to vector<3xi32>  // float -> integer elements
+  return %res : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantVecFPtoFP(
+func @bitcastConstantVecFPtoFP() -> vector<3xbf16> {
+  // CHECK: %[[C0:.+]] = constant dense<0.0{{.*}}> : vector<3xbf16>
+  // CHECK: return %[[C0]]
+  %c0 = constant dense<0.0> : vector<3xf16>
+  %res = bitcast %c0 : vector<3xf16> to vector<3xbf16>  // folded elementwise
+  return %res : vector<3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastBackAndForth(
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
+func @bitcastBackAndForth(%arg : i32) -> i32 {
+  // CHECK: return %[[ARG]]
+  %f = bitcast %arg : i32 to f32
+  %res = bitcast %f : f32 to i32  // i32 -> f32 -> i32 round trip folds to %arg
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastOfBitcast(
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
+func @bitcastOfBitcast(%arg : i16) -> i16 {
+  // CHECK: return %[[ARG]]
+  %f = bitcast %arg : i16 to f16
+  %bf = bitcast %f : f16 to bf16
+  %res = bitcast %bf : bf16 to i16  // the whole i16 -> f16 -> bf16 -> i16 chain folds to %arg
+  return %res : i16
+}

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index c45455b0d9695..c536e85a49b70 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -85,3 +85,11 @@ func @call() {
   %0:2 = call @return_i32_f32() : () -> (f32, i32)
   return
 }
+
+// -----
+
+func @bitcast_different_bit_widths(%arg : f16) -> f32 {
+  // expected-error@+1 {{are cast incompatible}}
+  %res = bitcast %arg : f16 to f32  // 16-bit source vs 32-bit result
+  return %res : f32
+}

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index a39cb742b431c..0f73a8d9358a4 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -80,3 +80,9 @@ func @constant_complex_f64() -> complex<f64> {
   %result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
   return %result : complex<f64>
 }
+
+// CHECK-LABEL: func @bitcast(
+func @bitcast(%arg : f32) -> i32 {
+  %res = bitcast %arg : f32 to i32  // equal bit widths (32), so the cast is valid
+  return %res : i32
+}


        


More information about the Mlir-commits mailing list