[Mlir-commits] [mlir] [mlir][spirv] Convert `bf16` to `spirv` as an `i16` (PR #76114)

Rob Suderman llvmlistbot at llvm.org
Wed Dec 20 18:13:14 PST 2023


https://github.com/rsuderman updated https://github.com/llvm/llvm-project/pull/76114

From 1bddf3e07bb95d6a71da56fb66a44e678f45cc74 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Wed, 20 Dec 2023 17:18:13 -0800
Subject: [PATCH 1/2] [mlir][spirv] Convert `bf16` to `spirv` as an `i16`

`bf16` is not currently supported in `spirv`. Existing conversions
already treat it as an `i16`, manipulating it with bit-shifting,
extension, and other integer operations. Convert `bf16` to `i16`
accordingly in the type converter.
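As a rough illustration of what treating `bf16` as a 16-bit integer means at the bit level, here is a small standalone C++ sketch. It is not MLIR code, and the function names are hypothetical. It relies only on the fact that a `bf16` value is the upper 16 bits of an IEEE-754 `f32` (sign, 8 exponent bits, top 7 mantissa bits), so extending is a left shift by 16 and truncating is a right shift by 16.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Assumes a 32-bit IEEE-754 float, which is what bf16 is derived from.
static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit float");

// Hypothetical helper: extend bf16 bits (held in a 16-bit integer) to f32
// by placing them in the high half of a 32-bit word and bit-casting.
inline float bf16BitsToFloat(std::uint16_t bits) {
  std::uint32_t word = static_cast<std::uint32_t>(bits) << 16;
  float result;
  std::memcpy(&result, &word, sizeof(result));
  return result;
}

// Hypothetical helper: truncate f32 to bf16 bits by keeping the high 16 bits.
// This rounds toward zero and does no special NaN handling; real lowerings
// may round to nearest-even instead.
inline std::uint16_t floatToBf16BitsTruncate(float value) {
  std::uint32_t word;
  std::memcpy(&word, &value, sizeof(word));
  return static_cast<std::uint16_t>(word >> 16);
}
```

For example, `1.0f` has the bit pattern `0x3F800000`, so its `bf16` bits are `0x3F80`, and extending `0x3F80` yields `1.0f` again.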
---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..87f85314afad06 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -713,6 +713,9 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   });
 
   addConversion([this](FloatType floatType) -> std::optional<Type> {
+    if (floatType.isBF16())
+      return convertType(IntegerType::get(floatType.getContext(), 16));
+
     if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
       return convertScalarType(this->targetEnv, this->options, scalarType);
     return Type();

From 25a6c920806456903cfc4d8cbccd7b3b52bc5601 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Wed, 20 Dec 2023 18:12:58 -0800
Subject: [PATCH 2/2] update test

---
 mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2e..0746ba979dec0e 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -211,7 +211,8 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
 } {
 
-// CHECK-NOT: spirv.func @bf16_type
+// CHECK: spirv.func @bf16_type
+// CHECK-SAME: i32
 func.func @bf16_type(%arg0: bf16) { return }
 
 } // end module
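The updated test expects `i32` rather than `i16` in the converted signature. A plausible reading, assuming the converter is configured to emulate sub-32-bit scalar types, is this: the test's target environment (`#spirv.vce<v1.0, [], []>`) declares no capabilities, so `Int16` is unavailable and the `i16` produced for `bf16` is widened to `i32`. The sketch below is a hypothetical, simplified model of that decision. It is not MLIR's API, and the function and parameter names are illustrative only.

```cpp
#include <cassert>
#include <optional>

// Hypothetical model (not MLIR API): pick the integer width a target uses
// for an integer of `width` bits.
//  - hasNativeSupport: the target has the capability for this width
//    (e.g. Int16 for 16-bit integers).
//  - emulateLT32Bit: widen unsupported sub-32-bit integers to 32 bits.
// Returns std::nullopt when the type cannot be converted.
std::optional<unsigned> convertedIntWidth(unsigned width, bool hasNativeSupport,
                                          bool emulateLT32Bit) {
  if (hasNativeSupport)
    return width;
  if (width < 32 && emulateLT32Bit)
    return 32u;
  return std::nullopt;
}
```

Under this model, `convertedIntWidth(16, /*hasNativeSupport=*/false, /*emulateLT32Bit=*/true)` gives 32, which matches the `CHECK-SAME: i32` line for a target without `Int16`.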


