[llvm] X86: Improve cost model of fp16 conversion (PR #113195)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 29 09:42:16 PDT 2024


================
@@ -3146,6 +3174,11 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                             TTI::CastContextHint::None, CostKind);
   }
 
+  if (ISD == ISD::FP_ROUND && LTDest.second.getScalarType() == MVT::f16) {
+    // Conversion requires a libcall.
+    return InstructionCost::getInvalid();
----------------
JoelWee wrote:

This change breaks the JAX test at https://github.com/google/jax/blob/main/tests/lax_test.py#L3630 (`LazyConstantTest.testConvertElementTypeAvoidsCopies21 (dtype_in=<class 'numpy.float64'>, dtype_out=<class 'numpy.float16'>)`), which now fails with:
```
F1029 08:45:30.640847    4013 logging.cc:62] assert.h assertion failed at [third_party/llvm/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp:4569](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp?l=4569&ws=joelwee/4894&snapshot=26) in VectorizationFactor llvm::LoopVectorizationPlanner::selectVectorizationFactor(): ExpectedCost.isValid() && "Unexpected invalid cost for scalar loop"
*** Check failure stack trace: ***
    @     0x7ef66f09cf59  absl::log_internal::LogMessage::SendToLog()
    @     0x7ef66f09c4fe  absl::log_internal::LogMessage::Flush()
    @     0x7ef66f09d519  absl::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7ef67ade7314  __assert_fail
    @     0x7efa86da3f10  llvm::LoopVectorizationPlanner::selectVectorizationFactor()
    @     0x7efa86db95df  llvm::LoopVectorizationPlanner::computeBestVF()
    @     0x7efa86dcbdfd  llvm::LoopVectorizePass::processLoop()
    @     0x7efa86dd2c3d  llvm::LoopVectorizePass::runImpl()
    @     0x7efa86dd3875  llvm::LoopVectorizePass::run()
    @     0x7efa8ceb7332  llvm::detail::PassModel<>::run()
    @     0x7ef9520b9050  llvm::PassManager<>::run()
    @     0x7efaff179412  llvm::detail::PassModel<>::run()
    @     0x7ef9520be28a  llvm::ModuleToFunctionPassAdaptor::run()
    @     0x7efaff179192  llvm::detail::PassModel<>::run()
    @     0x7ef9520b7d7c  llvm::PassManager<>::run()
    @     0x7efaa12ec861  xla::cpu::CompilerFunctor::operator()()
    @     0x7efa913b0271  llvm::orc::ThreadSafeModule::withModuleDo<>()
    @     0x7efa913b000b  llvm::orc::IRCompileLayer::emit()
    @     0x7efa913e6d45  llvm::orc::BasicIRLayerMaterializationUnit::materialize()
    @     0x7efa91454337  llvm::orc::InPlaceTaskDispatcher::dispatch()
    @     0x7efa91349466  llvm::orc::ExecutionSession::dispatchOutstandingMUs()
    @     0x7efa9134e9e6  llvm::orc::ExecutionSession::OL_completeLookup()
    @     0x7efa91369a89  llvm::orc::InProgressFullLookupState::complete()
    @     0x7efa9133a0f0  llvm::orc::ExecutionSession::OL_applyQueryPhase1()
    @     0x7efa91337234  llvm::orc::ExecutionSession::lookup()
    @     0x7efa9134991e  llvm::orc::ExecutionSession::lookup()
    @     0x7efa91349de8  llvm::orc::ExecutionSession::lookup()
    @     0x7efa9134a30e  llvm::orc::ExecutionSession::lookup()
    @     0x7efa9134a459  llvm::orc::ExecutionSession::lookup()
    @     0x7efaa1719abf  xla::cpu::SimpleOrcJIT::FindCompiledSymbol()
    @     0x7efaddc247c0  absl::internal_any_invocable::RemoteInvoker<>()
    @     0x7efaddc0fb68  std::__u::__function::__policy_invoker<>::__call_impl<>()
    @     0x7ef89847e1b6  tsl::thread::EigenEnvironment::ExecuteTask()
    @     0x7ef89847dd10  Eigen::ThreadPoolTempl<>::WorkerLoop()
    @     0x7ef89847d940  std::__u::invoke<>()
    @     0x7ef6a5f9e25e  Thread::ThreadBody()
    @     0x7efafb6827db  start_thread
    @     0x7efabc18e05f  clone
```


Here is the LLVM IR I dumped before the crash:

```
; Function Attrs: nofree norecurse nosync nounwind memory(readwrite, inaccessiblemem: none) uwtable
define noalias noundef ptr @convert.2(ptr nocapture readonly %0) local_unnamed_addr #0 {
  %args_gep = getelementptr inbounds nuw i8, ptr %0, i64 24
  %args = load ptr, ptr %args_gep, align 8
  %arg0 = load ptr, ptr %args, align 8, !invariant.load !0, !dereferenceable !1, !align !2
  %arg1_gep = getelementptr i8, ptr %args, i64 16
  %arg1 = load ptr, ptr %arg1_gep, align 8, !invariant.load !0, !dereferenceable !3, !align !2
  br label %convert.2.loop_body.dim.0

convert.2.loop_body.dim.0:                        ; preds = %1, %convert.2.loop_body.dim.0
  %convert.2.invar_address.dim.0.03 = phi i64 [ 0, %1 ], [ %invar.inc, %convert.2.loop_body.dim.0 ]
  %2 = getelementptr inbounds [5 x double], ptr %arg0, i64 0, i64 %convert.2.invar_address.dim.0.03
  %3 = load double, ptr %2, align 8, !invariant.load !0, !noalias !4
  %4 = fptrunc double %3 to half
  %5 = getelementptr inbounds [5 x half], ptr %arg1, i64 0, i64 %convert.2.invar_address.dim.0.03
  store half %4, ptr %5, align 2, !alias.scope !4
  %invar.inc = add nuw nsw i64 %convert.2.invar_address.dim.0.03, 1
  %exitcond = icmp eq i64 %invar.inc, 5
  br i1 %exitcond, label %return, label %convert.2.loop_body.dim.0

return:                                           ; preds = %convert.2.loop_body.dim.0
  ret ptr null
}
```
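To show why the hunk above can trip this assertion, here is a simplified C++ model. It is not LLVM code. `SketchCost`, `Inst`, `instCost`, `loopCost` and `convertLoopBody` are hypothetical names, and the unit costs are placeholders rather than X86 cost-model numbers. The sketch assumes that an invalid per-instruction cost makes the summed cost invalid. LLVM's `InstructionCost` propagates its invalid state through arithmetic in roughly this way. Under that assumption, the single `fptrunc double %3 to half` in the loop body is enough to make the scalar (VF=1) loop cost invalid.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for llvm::InstructionCost: a cost is
// either a finite value or "invalid", and invalid is sticky under addition.
class SketchCost {
 public:
  static SketchCost valid(int64_t v) { return SketchCost(v); }
  static SketchCost invalid() { return SketchCost(); }
  bool isValid() const { return value_.has_value(); }
  int64_t value() const {
    assert(isValid() && "value() called on an invalid cost");
    return *value_;
  }
  // If either operand is invalid, the sum is invalid.
  SketchCost operator+(const SketchCost &other) const {
    if (!isValid() || !other.isValid()) return invalid();
    return valid(*value_ + *other.value_);
  }

 private:
  SketchCost() = default;
  explicit SketchCost(int64_t v) : value_(v) {}
  std::optional<int64_t> value_;
};

// Hypothetical instruction record: opcode plus result/destination type.
struct Inst {
  std::string opcode;
  std::string destType;
};

// Hypothetical cost hook mirroring the hunk above: an fptrunc whose
// destination is half is reported as invalid; everything else costs 1.
SketchCost instCost(const Inst &I) {
  if (I.opcode == "fptrunc" && I.destType == "half")
    return SketchCost::invalid();
  return SketchCost::valid(1);
}

// Sum the per-instruction costs of a loop body.
SketchCost loopCost(const std::vector<Inst> &body) {
  SketchCost total = SketchCost::valid(0);
  for (const Inst &I : body) total = total + instCost(I);
  return total;
}

// The instructions of %convert.2.loop_body.dim.0 from the IR dump, in order.
std::vector<Inst> convertLoopBody() {
  return {{"phi", "i64"},   {"getelementptr", "ptr"}, {"load", "double"},
          {"fptrunc", "half"}, {"getelementptr", "ptr"}, {"store", "void"},
          {"add", "i64"},   {"icmp", "i1"},           {"br", "void"}};
}
```

In this model, `loopCost(convertLoopBody())` is invalid even though only one of the nine instructions is. That matches the `ExpectedCost.isValid()` failure above. Removing the `fptrunc` entry gives a finite total of 8.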

Could we fix this?

https://github.com/llvm/llvm-project/pull/113195
