[Mlir-commits] [mlir] 5c990d6 - [mlir] Add support for bf16 to StandardToLLVM conversion

Diego Caballero llvmlistbot at llvm.org
Thu Jun 4 14:39:29 PDT 2020


Author: Diego Caballero
Date: 2020-06-04T14:36:36-07:00
New Revision: 5c990d6994559225466cb256146f6440431b229e

URL: https://github.com/llvm/llvm-project/commit/5c990d6994559225466cb256146f6440431b229e
DIFF: https://github.com/llvm/llvm-project/commit/5c990d6994559225466cb256146f6440431b229e.diff

LOG: [mlir] Add support for bf16 to StandardToLLVM conversion

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D81127

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index ccf2185e4210..078cb1cfa4e5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -62,6 +62,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
   llvm::Type *getUnderlyingType() const;
 
   /// Utilities to identify types.
+  bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); }
   bool isHalfTy() { return getUnderlyingType()->isHalfTy(); }
   bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
   bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
@@ -99,6 +100,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
   /// Utilities used to generate floating point types.
   static LLVMType getDoubleTy(LLVMDialect *dialect);
   static LLVMType getFloatTy(LLVMDialect *dialect);
+  static LLVMType getBFloatTy(LLVMDialect *dialect);
   static LLVMType getHalfTy(LLVMDialect *dialect);
   static LLVMType getFP128Ty(LLVMDialect *dialect);
   static LLVMType getX86_FP80Ty(LLVMDialect *dialect);

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 4294e0024e79..5d3984d8ac90 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -201,9 +201,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
   case mlir::StandardTypes::F16:
     return LLVM::LLVMType::getHalfTy(llvmDialect);
   case mlir::StandardTypes::BF16: {
-    auto *mlirContext = llvmDialect->getContext();
-    return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
-           Type();
+    return LLVM::LLVMType::getBFloatTy(llvmDialect);
   }
   default:
     llvm_unreachable("non-float type in convertFloatType");

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2c6478ddd121..9fd8bfe6d26e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -939,8 +939,9 @@ static LogicalResult verify(DialectCastOp op) {
     if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
       if (llvmType.isVectorTy())
         llvmType = llvmType.getVectorElementType();
-      if (llvmType.isIntegerTy() || llvmType.isHalfTy() ||
-          llvmType.isFloatTy() || llvmType.isDoubleTy()) {
+      if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
+          llvmType.isHalfTy() || llvmType.isFloatTy() ||
+          llvmType.isDoubleTy()) {
         return success();
       }
       return op.emitOpError("type must be non-index integer types, float "
@@ -1500,7 +1501,8 @@ static LogicalResult verify(AtomicRMWOp op) {
   } else if (op.bin_op() == AtomicBinOp::xchg) {
     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
         !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
-        !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
+        !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
+        !valType.isDoubleTy())
       return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
   } else {
     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
@@ -1561,8 +1563,8 @@ static LogicalResult verify(AtomicCmpXchgOp op) {
                           "match type for all other operands");
   if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
       !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
-      !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() &&
-      !valType.isDoubleTy())
+      !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
+      !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
     return op.emitOpError("unexpected LLVM IR type");
   if (op.success_ordering() < AtomicOrdering::monotonic ||
       op.failure_ordering() < AtomicOrdering::monotonic)
@@ -1630,7 +1632,7 @@ struct LLVMDialectImpl {
   /// A set of LLVMTypes that are cached on construction to avoid any lookups or
   /// locking.
   LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
-  LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty;
+  LLVMType doubleTy, floatTy, bfloatTy, halfTy, fp128Ty, x86_fp80Ty;
   LLVMType voidTy;
 
   /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
@@ -1665,6 +1667,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
   /// Float Types.
   impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
   impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
+  impl->bfloatTy = LLVMType::get(context, llvm::Type::getBFloatTy(llvmContext));
   impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
   impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext));
   impl->x86_fp80Ty =
@@ -1827,6 +1830,9 @@ LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
 LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
   return dialect->impl->floatTy;
 }
+LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) {
+  return dialect->impl->bfloatTy;
+}
 LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
   return dialect->impl->halfTy;
 }

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index e2c3238b7bb0..ea21a6d9fea7 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1228,3 +1228,12 @@ func @mlir_cast_from_llvm(%0 : !llvm.half) -> f16 {
   // CHECK-NEXT: llvm.return %[[ARG]]
   return %1 : f16
 }
+
+// -----
+
+// CHECK-LABEL: func @bfloat
+// CHECK-SAME: !llvm.bfloat) -> !llvm.bfloat
+func @bfloat(%arg0: bf16) -> bf16 {
+  return %arg0 : bf16
+}
+// CHECK-NEXT: return %{{.*}} : !llvm.bfloat


        


More information about the Mlir-commits mailing list