<table border="1" cellspacing="0" cellpadding="8">
<tr>
<th>Issue</th>
<td>
<a href=https://github.com/llvm/llvm-project/issues/63424>63424</a>
</td>
</tr>
<tr>
<th>Summary</th>
<td>
Support FP16/BF16 in MLIR TOSA (Half-precision Tensors and ops)
</td>
</tr>
<tr>
<th>Labels</th>
<td>
</td>
</tr>
<tr>
<th>Assignees</th>
<td>
</td>
</tr>
<tr>
<th>Reporter</th>
<td>
guoqingbao
</td>
</tr>
</table>
<pre>
We intend to lower deep learning (DL) models to MLIR TOSA, and **we found that MLIR does not fully support half-precision ops**. For example, AvgPool2dOp in TOSA only accepts fp32, int8, and int16 input/output element types. We tried using tosa.CastOp to convert fp16/bf16 to fp32, but CastOp does not support half-precision inputs either. Given that half-precision training and inference are very important for large language models, **is there any plan for MLIR to support half-precision tensors and ops?**
``` C++
LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

  // Quantized tensors are checked against their storage type.
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  // Accumulator checks: f16/bf16 inputs are already anticipated here.
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if ((inputETy.isBF16() || inputETy.isF16()) &&
      !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  // Accepted input/output pairs: f32/f32, i8/i8, i16/i16.
  if (inputETy.isF32() && resultETy.isF32())
    return success();
  if (inputETy.isInteger(8) && resultETy.isInteger(8))
    return success();
  if (inputETy.isInteger(16) && resultETy.isInteger(16))
    return success();

  // Error for FP16 and BF16 inputs: no case above accepts them.
  return emitOpError("input/output element types are incompatible.");
}
```
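
A simplified, standalone model of the element-type checks in the verifier quoted above makes the gap visible: the accumulator check already allows f16/bf16 inputs (with an f16 or f32 accumulator), but no input/output rule accepts them, so they always reach the final error. `ElemTy`, `accTypeOk`, `currentVerify`, and `extendedVerify` are hypothetical names, quantized-type unwrapping is omitted, and `extendedVerify` is only one possible way half-precision support could be added, not an actual MLIR patch.

``` C++
// Hypothetical stand-in for the element types the verifier inspects.
// This is NOT the MLIR type API.
enum class ElemTy { I8, I16, I32, F16, BF16, F32 };

// Accumulator rule from the verifier: i32 for integer inputs,
// f16 or f32 for f16/bf16 inputs, f32 for f32 inputs.
constexpr bool accTypeOk(ElemTy in, ElemTy acc) {
  switch (in) {
  case ElemTy::I8:
  case ElemTy::I16:
  case ElemTy::I32:
    return acc == ElemTy::I32;
  case ElemTy::F16:
  case ElemTy::BF16:
    return acc == ElemTy::F16 || acc == ElemTy::F32;
  case ElemTy::F32:
    return acc == ElemTy::F32;
  }
  return false;
}

// Current behavior: only f32/f32, i8/i8 and i16/i16 pass; f16/f16 and
// bf16/bf16 fall through to "input/output element types are incompatible".
constexpr bool currentVerify(ElemTy in, ElemTy out, ElemTy acc) {
  if (!accTypeOk(in, acc)) return false;
  return (in == ElemTy::F32 && out == ElemTy::F32) ||
         (in == ElemTy::I8 && out == ElemTy::I8) ||
         (in == ElemTy::I16 && out == ElemTy::I16);
}

// Hypothetical extension: also accept matching half-precision pairs,
// which the accumulator rule above already anticipates.
constexpr bool extendedVerify(ElemTy in, ElemTy out, ElemTy acc) {
  if (!accTypeOk(in, acc)) return false;
  if ((in == ElemTy::F16 && out == ElemTy::F16) ||
      (in == ElemTy::BF16 && out == ElemTy::BF16))
    return true;
  return currentVerify(in, out, acc);
}

// An f16 input with an f32 accumulator passes the accumulator check,
// yet the current rule still rejects the f16 -> f16 pair.
static_assert(accTypeOk(ElemTy::F16, ElemTy::F32), "acc rule allows f32");
static_assert(!currentVerify(ElemTy::F16, ElemTy::F16, ElemTy::F32),
              "current rule rejects f16");
static_assert(extendedVerify(ElemTy::BF16, ElemTy::BF16, ElemTy::F32),
              "extended rule accepts bf16");
```

The `static_assert` lines check these claims at compile time. In the real verifier, the analogous change would be additional `isF16()`/`isBF16()` success cases placed before the final `emitOpError`.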
</pre>