[Mlir-commits] [mlir] [mlir][spirv] Convert `bf16` to `spirv` as a `i16` (PR #76114)
Rob Suderman
llvmlistbot at llvm.org
Wed Dec 20 18:13:14 PST 2023
https://github.com/rsuderman updated https://github.com/llvm/llvm-project/pull/76114
From 1bddf3e07bb95d6a71da56fb66a44e678f45cc74 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Wed, 20 Dec 2023 17:18:13 -0800
Subject: [PATCH 1/2] [mlir][spirv] Convert `bf16` to `spirv` as a `i16`
`bf16` is not currently supported in `spirv`. Existing conversions
already treat it as an `i16`, using bit-shifting, extension, and other
bit-level manipulations. Convert it to `i16` accordingly.
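Carrying `bf16` as a 16-bit integer works because a `bf16` value is bit-for-bit the upper half of an IEEE-754 `f32`: it keeps the sign, all 8 exponent bits, and the top 7 mantissa bits. Widening is therefore a left shift by 16 followed by a bit reinterpretation, and narrowing (by truncation) is the reverse. The sketch below is a minimal scalar C++ illustration of that idea, not the ops MLIR actually emits; the helper names `bf16BitsToFloat` and `floatToBF16BitsTruncate` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Widen bf16 bits (held in a uint16_t) to f32: move them into the upper
// half of a 32-bit word, zero the lower half, and reinterpret as a float.
float bf16BitsToFloat(std::uint16_t bits) {
  std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
  float result;
  std::memcpy(&result, &wide, sizeof(result));
  return result;
}

// Narrow f32 to bf16 bits by truncation: keep only the upper 16 bits.
// Truncation is the simplest choice; round-to-nearest-even is also common.
std::uint16_t floatToBF16BitsTruncate(float value) {
  std::uint32_t wide;
  std::memcpy(&wide, &value, sizeof(wide));
  return static_cast<std::uint16_t>(wide >> 16);
}
```

A round trip through these helpers discards the low 16 mantissa bits of the `f32`, which is exactly the precision `bf16` gives up; the bit pattern itself never needs a floating-point type, which is why an `i16` carrier is sufficient.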
---
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..87f85314afad06 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -713,6 +713,9 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   });
 
   addConversion([this](FloatType floatType) -> std::optional<Type> {
+    if (floatType.isBF16())
+      return convertType(IntegerType::get(floatType.getContext(), 16));
+
     if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
       return convertScalarType(this->targetEnv, this->options, scalarType);
     return Type();
From 25a6c920806456903cfc4d8cbccd7b3b52bc5601 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Wed, 20 Dec 2023 18:12:58 -0800
Subject: [PATCH 2/2] update test
---
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2e..0746ba979dec0e 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -211,7 +211,8 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
 } {
 
-// CHECK-NOT: spirv.func @bf16_type
+// CHECK: spirv.func @bf16_type
+// CHECK-SAME: i32
 func.func @bf16_type(%arg0: bf16) { return }
 
 } // end module
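The updated test expects `i32`, not `i16`, even though the patch maps `bf16` to `i16`. A plausible reading: the patch re-enters `convertType` with `i16`, and because this test's target environment (`#spirv.vce<v1.0, [], []>`) declares no capabilities, the existing integer rules presumably widen the unsupported 16-bit type to `i32`. The standalone C++ model below sketches that chain under those assumptions. It is not MLIR code: `TargetModel`, `convertIntegerWidth`, and `convertFloatName` are hypothetical, and `emulateLT32BitTypes` is only meant to mirror the converter's option for emulating sub-32-bit scalar types.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical target description; not part of the MLIR SPIR-V API.
struct TargetModel {
  bool hasInt16Capability;  // target env lists the Int16 capability
  bool emulateLT32BitTypes; // widen unsupported sub-32-bit scalars to 32 bits
};

// Integer rule (only the 16-bit case is modeled): use i16 when the target
// supports it, widen to i32 when emulation is allowed, otherwise fail.
std::optional<std::string> convertIntegerWidth(unsigned width,
                                               const TargetModel &target) {
  if (width == 16 && !target.hasInt16Capability) {
    if (!target.emulateLT32BitTypes)
      return std::nullopt;      // no legal representation on this target
    return std::string("i32");  // emulated with a wider integer
  }
  return "i" + std::to_string(width);
}

// Float rule after the patch: bf16 is re-routed through the integer rule
// as a 16-bit integer; only f32 is modeled among the other float types.
std::optional<std::string> convertFloatName(const std::string &name,
                                            const TargetModel &target) {
  if (name == "bf16")
    return convertIntegerWidth(16, target);
  if (name == "f32")
    return std::string("f32");
  return std::nullopt;  // other float types are outside this sketch
}
```

Routing `bf16` back through the integer path, rather than special-casing its final type, means whatever the target allows for `i16` (native, widened, or rejected) applies to `bf16` automatically.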