[Mlir-commits] [mlir] 5c990d6 - [mlir] Add support for bf16 to StandardToLLVM conversion
Diego Caballero
llvmlistbot at llvm.org
Thu Jun 4 14:39:29 PDT 2020
Author: Diego Caballero
Date: 2020-06-04T14:36:36-07:00
New Revision: 5c990d6994559225466cb256146f6440431b229e
URL: https://github.com/llvm/llvm-project/commit/5c990d6994559225466cb256146f6440431b229e
DIFF: https://github.com/llvm/llvm-project/commit/5c990d6994559225466cb256146f6440431b229e.diff
LOG: [mlir] Add support for bf16 to StandardToLLVM conversion
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D81127
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index ccf2185e4210..078cb1cfa4e5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -62,6 +62,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
llvm::Type *getUnderlyingType() const;
/// Utilities to identify types.
+ bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); }
bool isHalfTy() { return getUnderlyingType()->isHalfTy(); }
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
@@ -99,6 +100,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
/// Utilities used to generate floating point types.
static LLVMType getDoubleTy(LLVMDialect *dialect);
static LLVMType getFloatTy(LLVMDialect *dialect);
+ static LLVMType getBFloatTy(LLVMDialect *dialect);
static LLVMType getHalfTy(LLVMDialect *dialect);
static LLVMType getFP128Ty(LLVMDialect *dialect);
static LLVMType getX86_FP80Ty(LLVMDialect *dialect);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 4294e0024e79..5d3984d8ac90 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -201,9 +201,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
case mlir::StandardTypes::F16:
return LLVM::LLVMType::getHalfTy(llvmDialect);
case mlir::StandardTypes::BF16: {
- auto *mlirContext = llvmDialect->getContext();
- return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
- Type();
+ return LLVM::LLVMType::getBFloatTy(llvmDialect);
}
default:
llvm_unreachable("non-float type in convertFloatType");
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2c6478ddd121..9fd8bfe6d26e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -939,8 +939,9 @@ static LogicalResult verify(DialectCastOp op) {
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
if (llvmType.isVectorTy())
llvmType = llvmType.getVectorElementType();
- if (llvmType.isIntegerTy() || llvmType.isHalfTy() ||
- llvmType.isFloatTy() || llvmType.isDoubleTy()) {
+ if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
+ llvmType.isHalfTy() || llvmType.isFloatTy() ||
+ llvmType.isDoubleTy()) {
return success();
}
return op.emitOpError("type must be non-index integer types, float "
@@ -1500,7 +1501,8 @@ static LogicalResult verify(AtomicRMWOp op) {
} else if (op.bin_op() == AtomicBinOp::xchg) {
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
!valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
- !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
+ !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
+ !valType.isDoubleTy())
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
@@ -1561,8 +1563,8 @@ static LogicalResult verify(AtomicCmpXchgOp op) {
"match type for all other operands");
if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
!valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
- !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() &&
- !valType.isDoubleTy())
+ !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
+ !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
return op.emitOpError("unexpected LLVM IR type");
if (op.success_ordering() < AtomicOrdering::monotonic ||
op.failure_ordering() < AtomicOrdering::monotonic)
@@ -1630,7 +1632,7 @@ struct LLVMDialectImpl {
/// A set of LLVMTypes that are cached on construction to avoid any lookups or
/// locking.
LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
- LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty;
+ LLVMType doubleTy, floatTy, bfloatTy, halfTy, fp128Ty, x86_fp80Ty;
LLVMType voidTy;
/// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
@@ -1665,6 +1667,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
/// Float Types.
impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
+ impl->bfloatTy = LLVMType::get(context, llvm::Type::getBFloatTy(llvmContext));
impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext));
impl->x86_fp80Ty =
@@ -1827,6 +1830,9 @@ LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
return dialect->impl->floatTy;
}
+LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) {
+ return dialect->impl->bfloatTy;
+}
LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
return dialect->impl->halfTy;
}
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index e2c3238b7bb0..ea21a6d9fea7 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1228,3 +1228,12 @@ func @mlir_cast_from_llvm(%0 : !llvm.half) -> f16 {
// CHECK-NEXT: llvm.return %[[ARG]]
return %1 : f16
}
+
+// -----
+
+// CHECK-LABEL: func @bfloat
+// CHECK-SAME: !llvm.bfloat) -> !llvm.bfloat
+func @bfloat(%arg0: bf16) -> bf16 {
+ return %arg0 : bf16
+}
+// CHECK-NEXT: return %{{.*}} : !llvm.bfloat
More information about the Mlir-commits mailing list