<table border="1" cellspacing="0" cellpadding="8">
    <tr>
        <th>Issue</th>
        <td>
            <a href=https://github.com/llvm/llvm-project/issues/63424>63424</a>
        </td>
    </tr>

    <tr>
        <th>Summary</th>
        <td>
            Support FP16/BF16 in MLIR TOSA (Half-precision Tensors and ops)
        </td>
    </tr>

    <tr>
      <th>Labels</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Assignees</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Reporter</th>
      <td>
          guoqingbao
      </td>
    </tr>
</table>

<pre>
    We intend to lower DL models to MLIR TOSA, and **we found that MLIR TOSA does not fully support half-precision ops**. For example, AvgPool2dOp in TOSA only accepts fp32, int8 and int16 inputs (see its verifier below). We tried to use tosa.CastOp to convert fp16/bf16 to fp32 first, but CastOp does not support half-precision inputs either. Given that half-precision training and inference are very important for large language models, **is there any plan in MLIR to support half-precision tensors and ops?**
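
For reference, the widening we hoped tosa.CastOp would perform (f16/bf16 to f32) is exact and purely bit-level. The sketch below is plain C++ written under our own assumptions, not TOSA or MLIR code; the helper names `bitsToF32`, `bf16ToF32` and `f16ToF32` are hypothetical.

``` C++
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper: reinterpret a 32-bit pattern as an IEEE-754 binary32.
static float bitsToF32(uint32_t bits) {
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// bf16 is the upper half of a binary32, so widening is a 16-bit shift.
float bf16ToF32(uint16_t bits) {
  return bitsToF32(static_cast<uint32_t>(bits) << 16);
}

// Widen an IEEE-754 binary16 (f16) bit pattern to binary32.
float f16ToF32(uint16_t bits) {
  const bool negative = (bits & 0x8000u) != 0;
  const uint32_t sign = negative ? 0x80000000u : 0u;
  const uint32_t exp = (bits >> 10) & 0x1Fu;  // 5-bit exponent, bias 15
  const uint32_t mant = bits & 0x3FFu;        // 10-bit mantissa
  if (exp == 0) {
    // Zero or subnormal: value = mant * 2^-24, exactly representable in f32.
    const float v = std::ldexp(static_cast<float>(mant), -24);
    return negative ? -v : v;
  }
  if (exp == 0x1Fu) {
    // Inf (mant == 0) or NaN (mant != 0); the payload bits are kept.
    return bitsToF32(sign | 0x7F800000u | (mant << 13));
  }
  // Normal: rebias the exponent (15 -> 127), widen the mantissa (10 -> 23 bits).
  return bitsToF32(sign | ((exp + 112u) << 23) | (mant << 13));
}
```

Every f16 and bf16 value is representable in f32, so this direction never rounds; only the reverse cast would need a rounding mode.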

``` C++
LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

  // Quantized tensors are checked against their storage (integer) type.
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  // f16/bf16 inputs with an f16/f32 accumulator get past this check.
  if ((inputETy.isBF16() || inputETy.isF16()) &&
      !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy.isF32() && resultETy.isF32())
    return success();
  if (inputETy.isInteger(8) && resultETy.isInteger(8))
    return success();
  if (inputETy.isInteger(16) && resultETy.isInteger(16))
    return success();

  // Error for FP16 and BF16 inputs: no f16/bf16 case is accepted above.
  return emitOpError("input/output element types are incompatible.");
}
```
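
To make the failure path explicit, here is a simplified, self-contained model of the checks quoted above. `ElemTy` and `verifyAvgPool2d` are hypothetical names, quantized types are left out, and the `acceptHalf` flag is only our assumption of what accepting f16/bf16 in the final check could look like; none of this is existing MLIR code.

``` C++
#include <string>

// Hypothetical stand-in for the element types the verifier inspects
// (not MLIR's Type API; quantized types are omitted).
enum class ElemTy { I8, I16, I32, F16, BF16, F32 };

// Returns "" on success, otherwise the error the quoted verifier would emit.
// acceptHalf is a hypothetical switch modelling f16/bf16 support.
std::string verifyAvgPool2d(ElemTy input, ElemTy result, ElemTy acc,
                            bool acceptHalf = false) {
  const bool isInt =
      input == ElemTy::I8 || input == ElemTy::I16 || input == ElemTy::I32;
  const bool isHalf = input == ElemTy::F16 || input == ElemTy::BF16;

  if (isInt && acc != ElemTy::I32)
    return "accumulator type for integer tensor is not i32";
  if (isHalf && !(acc == ElemTy::F16 || acc == ElemTy::F32))
    return "accumulator type for f16/bf16 tensor is not f16/f32";
  if (input == ElemTy::F32 && acc != ElemTy::F32)
    return "accumulator type for f32 tensor is not f32";

  // Final input/result pairing check; f16/bf16 only pass if acceptHalf is set.
  const bool supportedPair =
      input == result &&
      (input == ElemTy::F32 || input == ElemTy::I8 || input == ElemTy::I16 ||
       (acceptHalf && isHalf));
  if (supportedPair)
    return "";
  return "input/output element types are incompatible.";
}
```

With `acceptHalf` left at `false`, `verifyAvgPool2d(ElemTy::F16, ElemTy::F16, ElemTy::F32)` returns the "incompatible" error even though the accumulator check already passed, which matches the behaviour we observe.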
</pre>