[flang-commits] [flang] e55aa02 - [flang] Fix runtime error messages for the MATMUL intrinsic (#96928)
via flang-commits
flang-commits at lists.llvm.org
Thu Jun 27 14:54:07 PDT 2024
Author: Pete Steinfeld
Date: 2024-06-27T14:54:02-07:00
New Revision: e55aa027f813679ca63c9b803690ce792a3d7b28
URL: https://github.com/llvm/llvm-project/commit/e55aa027f813679ca63c9b803690ce792a3d7b28
DIFF: https://github.com/llvm/llvm-project/commit/e55aa027f813679ca63c9b803690ce792a3d7b28.diff
LOG: [flang] Fix runtime error messages for the MATMUL intrinsic (#96928)
There are three forms of MATMUL: the first argument is a rank-1 array,
the second argument is a rank-1 array, or both arguments are rank-2
arrays. The runtime already detects when the operand shapes are not
conformable, but the code that emits the error message assumed that
both arguments were rank-2 arrays.
This change adds error messages for the other two cases.
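For reference, the conformability rule behind the three forms can be sketched in C++. This is a minimal illustration, not flang runtime code: `MatmulResultShape` is a hypothetical helper that takes each operand's extents as a vector whose size is the rank, assumes both ranks are 1 or 2, and returns std::nullopt when the operands cannot be multiplied. The core test mirrors the check in the patch: the last extent of x must equal the first extent of y.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of the flang runtime): given the extents of
// MATMUL's operands, return the result extents, or std::nullopt when the
// operands do not form a valid, conformable MATMUL.
std::optional<std::vector<std::int64_t>> MatmulResultShape(
    const std::vector<std::int64_t> &x, const std::vector<std::int64_t> &y) {
  // Assumed precondition: each operand has rank 1 or 2.
  assert(x.size() == 1 || x.size() == 2);
  assert(y.size() == 1 || y.size() == 2);
  if (x.size() == 1 && y.size() == 1) {
    return std::nullopt; // at least one operand must be rank 2
  }
  // The contracted extent: the last extent of x must equal the first of y.
  if (x.back() != y.front()) {
    return std::nullopt;
  }
  if (x.size() == 1) {
    return std::vector<std::int64_t>{y[1]}; // (n) * (n,m) -> (m)
  }
  if (y.size() == 1) {
    return std::vector<std::int64_t>{x[0]}; // (k,n) * (n) -> (k)
  }
  return std::vector<std::int64_t>{x[0], y[1]}; // (k,n) * (n,m) -> (k,m)
}
```

Each branch corresponds to one of the three forms; only the non-conformable case reaches the error path that this commit fixes.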
Added:
Modified:
flang/runtime/matmul.cpp
Removed:
################################################################################
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index 543284cb5c363..8f9b50a549e1f 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -288,11 +288,25 @@ static inline RT_API_ATTRS void DoMatmul(
}
SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
if (n != y.GetDimension(0).Extent()) {
- terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
- static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
- static_cast<std::intmax_t>(n),
- static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
- static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
+ // At this point, we know that there's a shape error. There are three
+ // possibilities, x is rank 1, y is rank 1, or both are rank 2.
+ if (xRank == 1) {
+ terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
+ static_cast<std::intmax_t>(n),
+ static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
+ static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
+ } else if (yRank == 1) {
+ terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
+ static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
+ static_cast<std::intmax_t>(n),
+ static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
+ } else {
+ terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
+ static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
+ static_cast<std::intmax_t>(n),
+ static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
+ static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
+ }
}
using WriteResult =
CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
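The message selection in the patch can be exercised outside the runtime with a small stand-alone sketch. It is an approximation under stated assumptions: `FormatMatmulShapeError` is a hypothetical function, std::snprintf stands in for terminator.Crash (which aborts rather than returning text), and the extents are passed as plain arrays instead of being read from descriptors. The format strings and argument order follow the three branches of the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>

// Hypothetical helper mirroring the patch: builds the MATMUL shape-error
// text for the three operand forms. Only the first `rank` entries of each
// extents array are read.
std::string FormatMatmulShapeError(int xRank, const std::intmax_t xExtents[2],
    int yRank, const std::intmax_t yExtents[2]) {
  assert((xRank == 1 || xRank == 2) && (yRank == 1 || yRank == 2));
  assert(!(xRank == 1 && yRank == 1)); // not a valid MATMUL form
  char buf[160];
  std::intmax_t n{xExtents[xRank - 1]}; // contracted extent of x
  if (xRank == 1) {
    // x is a vector: print its single extent, then y's two extents.
    std::snprintf(buf, sizeof buf,
        "MATMUL: unacceptable operand shapes (%jd, %jdx%jd)", n, yExtents[0],
        yExtents[1]);
  } else if (yRank == 1) {
    // y is a vector: print x's two extents, then y's single extent.
    std::snprintf(buf, sizeof buf,
        "MATMUL: unacceptable operand shapes (%jdx%jd, %jd)", xExtents[0], n,
        yExtents[0]);
  } else {
    // Both operands are matrices.
    std::snprintf(buf, sizeof buf,
        "MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", xExtents[0],
        n, yExtents[0], yExtents[1]);
  }
  return std::string{buf};
}
```

With x of shape 2x3 and a rank-1 y of extent 4, for example, the sketch yields `MATMUL: unacceptable operand shapes (2x3, 4)`, and it never reads a second extent from a rank-1 operand.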