[flang-commits] [flang] [flang] Fix runtime error messages for the MATMUL intrinsic (PR #96928)

Pete Steinfeld via flang-commits flang-commits at lists.llvm.org
Thu Jun 27 09:34:09 PDT 2024


https://github.com/psteinfeld created https://github.com/llvm/llvm-project/pull/96928

There are three forms of MATMUL: the first argument is a rank 1 array, the second argument is a rank 1 array, or both arguments are rank 2 arrays. The runtime detects when the operand shapes do not conform, but the code that emits the error message assumes that both arguments are rank 2 arrays, so the message misreports the shapes when either argument is rank 1.

This change adds messages for the other two cases.
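The conformance rule and the message format that the patch produces for each form can be sketched in standalone C++. This is a minimal sketch, not the flang runtime code. `MatmulShapeError` is a hypothetical helper that takes operand extents as `std::vector<std::intmax_t>` instead of runtime descriptors, and it returns the message rather than calling `terminator.Crash`. It assumes, as the Fortran standard requires, that each operand has rank 1 or 2 and at least one has rank 2.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper (not part of the flang runtime). Each operand is given
// by its extents, so its rank is extents.size(). Returns an empty string when
// the shapes conform, otherwise a message in the patched runtime's format.
std::string MatmulShapeError(const std::vector<std::intmax_t> &x,
                             const std::vector<std::intmax_t> &y) {
  // Assumed precondition: ranks are 1 or 2, and at least one operand is rank 2.
  assert((x.size() == 1 || x.size() == 2) && (y.size() == 1 || y.size() == 2) &&
         (x.size() == 2 || y.size() == 2));
  // n is the extent of x's last dimension; it must equal y's first extent.
  std::intmax_t n{x.back()};
  if (n == y.front()) {
    return {};  // shapes conform
  }
  char buf[160];
  if (x.size() == 1) {
    // MATMUL(vector, matrix): x has shape (n), y has shape (k, m)
    std::snprintf(buf, sizeof buf,
        "MATMUL: unacceptable operand shapes (%jd, %jdx%jd)", n, y[0], y[1]);
  } else if (y.size() == 1) {
    // MATMUL(matrix, vector): x has shape (m, n), y has shape (k)
    std::snprintf(buf, sizeof buf,
        "MATMUL: unacceptable operand shapes (%jdx%jd, %jd)", x[0], n, y[0]);
  } else {
    // MATMUL(matrix, matrix): both operands are rank 2
    std::snprintf(buf, sizeof buf,
        "MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", x[0], n,
        y[0], y[1]);
  }
  return buf;
}
```

For example, `MatmulShapeError({3}, {4, 5})` yields `MATMUL: unacceptable operand shapes (3, 4x5)`, which corresponds to the `xRank == 1` branch of the patch. Before the fix, the rank 1 operand would have been printed as if it had two dimensions.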

From 313670479413d1ca7499fdf01d6472bdc8a8e2a6 Mon Sep 17 00:00:00 2001
From: Peter Steinfeld <psteinfeld at nvidia.com>
Date: Thu, 27 Jun 2024 09:24:47 -0700
Subject: [PATCH] [flang] Fix runtime error messages for the MATMUL intrinsic

There are three forms of MATMUL: the first argument is a rank 1 array,
the second argument is a rank 1 array, or both arguments are rank 2
arrays.  The runtime detects when the operand shapes do not conform,
but the code that emits the error message assumes that both arguments
are rank 2 arrays, so the message misreports the shapes when either
argument is rank 1.

This change adds messages for the other two cases.
---
 flang/runtime/matmul.cpp | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index 543284cb5c363..8f9b50a549e1f 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -288,11 +288,25 @@ static inline RT_API_ATTRS void DoMatmul(
   }
   SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
   if (n != y.GetDimension(0).Extent()) {
-    terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
-        static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
-        static_cast<std::intmax_t>(n),
-        static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
-        static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
+    // At this point, we know that there's a shape error.  There are three
+    // possibilities: x is rank 1, y is rank 1, or both are rank 2.
+    if (xRank == 1) {
+      terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
+          static_cast<std::intmax_t>(n),
+          static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
+          static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
+    } else if (yRank == 1) {
+      terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
+          static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
+          static_cast<std::intmax_t>(n),
+          static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
+    } else {
+      terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
+          static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
+          static_cast<std::intmax_t>(n),
+          static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
+          static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
+    }
   }
   using WriteResult =
       CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,


