[flang-commits] [flang] [flang] Lower MATMUL to type specific runtime calls. (PR #97547)

via flang-commits flang-commits at lists.llvm.org
Wed Jul 3 02:41:02 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-runtime

Author: Slava Zakharin (vzakhari)

<details>
<summary>Changes</summary>

Lower MATMUL to the new runtime entries added in #97406.


---

Patch is 53.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97547.diff


15 Files Affected:

- (modified) flang/include/flang/Optimizer/Support/Utils.h (+73-4) 
- (modified) flang/include/flang/Runtime/matmul-instances.inc (+14-9) 
- (modified) flang/include/flang/Runtime/matmul-transpose.h (+2) 
- (modified) flang/include/flang/Runtime/matmul.h (+2) 
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+8-7) 
- (modified) flang/lib/Optimizer/Builder/Runtime/Transformational.cpp (+92-4) 
- (modified) flang/runtime/matmul-transpose.cpp (+2-51) 
- (modified) flang/runtime/matmul.cpp (+2-51) 
- (modified) flang/test/HLFIR/matmul-lowering.fir (+3-3) 
- (modified) flang/test/HLFIR/mul_transpose.f90 (+3-3) 
- (modified) flang/test/Lower/Intrinsics/matmul.f90 (+2-2) 
- (modified) flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h (+9) 
- (modified) flang/unittests/Optimizer/Builder/Runtime/TransformationalTest.cpp (+34-8) 
- (modified) flang/unittests/Runtime/Matmul.cpp (-119) 
- (modified) flang/unittests/Runtime/MatmulTranspose.cpp (-131) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index ae95a26be1d86..2ffb48335686c 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -84,9 +84,10 @@ inline std::string mlirTypeToString(mlir::Type type) {
   return result;
 }
 
-inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
-                                            mlir::Type type, mlir::Location loc,
-                                            const llvm::Twine &name) {
+inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
+                                              mlir::Type type,
+                                              mlir::Location loc,
+                                              const llvm::Twine &name) {
   if (type.isF16())
     return "REAL(KIND=2)";
   else if (type.isBF16())
@@ -123,6 +124,14 @@ inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
     return "COMPLEX(KIND=10)";
   else if (type == fir::ComplexType::get(builder.getContext(), 16))
     return "COMPLEX(KIND=16)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 1))
+    return "LOGICAL(KIND=1)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 2))
+    return "LOGICAL(KIND=2)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 4))
+    return "LOGICAL(KIND=4)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 8))
+    return "LOGICAL(KIND=8)";
   else
     fir::emitFatalError(loc, "unsupported type in " + name + ": " +
                                  fir::mlirTypeToString(type));
@@ -133,10 +142,76 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
                               const llvm::Twine &intrinsicName) {
   TODO(loc,
        "intrinsic: " +
-           fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
+           fir::mlirTypeToIntrinsicFortran(builder, type, loc, intrinsicName) +
            " in " + intrinsicName);
 }

+/// Report a TODO for intrinsic \p intrinsicName whose two arguments have the
+/// element types \p type1 and \p type2, in the form
+/// "intrinsic: {<type1>, <type2>} in <intrinsicName>".
+inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
+                               mlir::Type type2, mlir::Location loc,
+                               const llvm::Twine &intrinsicName) {
+  TODO(loc,
+       "intrinsic: {" +
+           fir::mlirTypeToIntrinsicFortran(builder, type1, loc, intrinsicName) +
+           ", " +
+           fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
+           "} in " + intrinsicName);
+}
+
+/// Map \p type (a REAL, INTEGER, COMPLEX or LOGICAL scalar type) to its
+/// Fortran (category, kind) pair. Any other type is reported via
+/// fir::emitFatalError at \p loc.
+inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
+mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
+  if (type.isF16())
+    return {Fortran::common::TypeCategory::Real, 2};
+  else if (type.isBF16())
+    return {Fortran::common::TypeCategory::Real, 3};
+  else if (type.isF32())
+    return {Fortran::common::TypeCategory::Real, 4};
+  else if (type.isF64())
+    return {Fortran::common::TypeCategory::Real, 8};
+  else if (type.isF80())
+    return {Fortran::common::TypeCategory::Real, 10};
+  else if (type.isF128())
+    return {Fortran::common::TypeCategory::Real, 16};
+  else if (type.isInteger(8))
+    return {Fortran::common::TypeCategory::Integer, 1};
+  else if (type.isInteger(16))
+    return {Fortran::common::TypeCategory::Integer, 2};
+  else if (type.isInteger(32))
+    return {Fortran::common::TypeCategory::Integer, 4};
+  else if (type.isInteger(64))
+    return {Fortran::common::TypeCategory::Integer, 8};
+  else if (type.isInteger(128))
+    return {Fortran::common::TypeCategory::Integer, 16};
+  else if (type == fir::ComplexType::get(loc.getContext(), 2))
+    return {Fortran::common::TypeCategory::Complex, 2};
+  else if (type == fir::ComplexType::get(loc.getContext(), 3))
+    return {Fortran::common::TypeCategory::Complex, 3};
+  else if (type == fir::ComplexType::get(loc.getContext(), 4))
+    return {Fortran::common::TypeCategory::Complex, 4};
+  else if (type == fir::ComplexType::get(loc.getContext(), 8))
+    return {Fortran::common::TypeCategory::Complex, 8};
+  else if (type == fir::ComplexType::get(loc.getContext(), 10))
+    return {Fortran::common::TypeCategory::Complex, 10};
+  else if (type == fir::ComplexType::get(loc.getContext(), 16))
+    return {Fortran::common::TypeCategory::Complex, 16};
+  else if (type == fir::LogicalType::get(loc.getContext(), 1))
+    return {Fortran::common::TypeCategory::Logical, 1};
+  else if (type == fir::LogicalType::get(loc.getContext(), 2))
+    return {Fortran::common::TypeCategory::Logical, 2};
+  else if (type == fir::LogicalType::get(loc.getContext(), 4))
+    return {Fortran::common::TypeCategory::Logical, 4};
+  else if (type == fir::LogicalType::get(loc.getContext(), 8))
+    return {Fortran::common::TypeCategory::Logical, 8};
+  else
+    fir::emitFatalError(loc,
+                        "unsupported type: " + fir::mlirTypeToString(type));
+}
+
 /// Find the fir.type_info that was created for this \p recordType in \p module,
 /// if any. \p  symbolTable can be provided to speed-up the lookup. This tool
 /// will match record type even if they have been "altered" in type conversion
diff --git a/flang/include/flang/Runtime/matmul-instances.inc b/flang/include/flang/Runtime/matmul-instances.inc
index 970b03339cd5e..32c6ab06d2521 100644
--- a/flang/include/flang/Runtime/matmul-instances.inc
+++ b/flang/include/flang/Runtime/matmul-instances.inc
@@ -17,6 +17,10 @@
 #error "Define MATMUL_DIRECT_INSTANCE before including this file"
 #endif
 
+#ifndef MATMUL_FORCE_ALL_TYPES
+#error "Define MATMUL_FORCE_ALL_TYPES to 0 or 1 before including this file"
+#endif
+
 // clang-format off
 
 #define FOREACH_MATMUL_TYPE_PAIR(macro)         \
@@ -88,7 +92,7 @@
 FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
 
-#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#if MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
 #define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro)      \
   macro(Integer, 16, Integer, 1)                        \
   macro(Integer, 16, Integer, 2)                        \
@@ -107,7 +111,7 @@ FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
 
-#if LDBL_MANT_DIG == 64
+#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
 MATMUL_INSTANCE(Integer, 16, Real, 10)
 MATMUL_INSTANCE(Integer, 16, Complex, 10)
 MATMUL_INSTANCE(Real, 10, Integer, 16)
@@ -117,7 +121,7 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
 MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
 MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
 #endif
-#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
 MATMUL_INSTANCE(Integer, 16, Real, 16)
 MATMUL_INSTANCE(Integer, 16, Complex, 16)
 MATMUL_INSTANCE(Real, 16, Integer, 16)
@@ -127,9 +131,9 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
 MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
 MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
 #endif
-#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#endif // MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
 
-#if LDBL_MANT_DIG == 64
+#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
 #define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro)         \
   macro(Integer, 1, Real, 10)                               \
   macro(Integer, 1, Complex, 10)                            \
@@ -171,7 +175,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
 
-#if HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT128
 MATMUL_INSTANCE(Real, 10, Real, 16)
 MATMUL_INSTANCE(Real, 10, Complex, 16)
 MATMUL_INSTANCE(Real, 16, Real, 10)
@@ -189,9 +193,9 @@ MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
 MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
 MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
 #endif
-#endif // LDBL_MANT_DIG == 64
+#endif // MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
 
-#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
 #define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro)         \
   macro(Integer, 1, Real, 16)                               \
   macro(Integer, 1, Complex, 16)                            \
@@ -232,7 +236,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
 
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
-#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#endif // MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
 
 #define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
   macro(Logical, 1, Logical, 1)                 \
@@ -257,5 +261,6 @@ FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
 
 #undef MATMUL_INSTANCE
 #undef MATMUL_DIRECT_INSTANCE
+#undef MATMUL_FORCE_ALL_TYPES
 
 // clang-format on
diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h
index d0a5005a1a8bd..2d79ca10e0895 100644
--- a/flang/include/flang/Runtime/matmul-transpose.h
+++ b/flang/include/flang/Runtime/matmul-transpose.h
@@ -40,6 +40,8 @@ void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
       Descriptor & result, const Descriptor &x, const Descriptor &y, \
       const char *sourceFile, int line);
 
+#define MATMUL_FORCE_ALL_TYPES 0
+
 #include "matmul-instances.inc"
 
 } // extern "C"
diff --git a/flang/include/flang/Runtime/matmul.h b/flang/include/flang/Runtime/matmul.h
index 1a5e39eb8813f..a72d4a06ee459 100644
--- a/flang/include/flang/Runtime/matmul.h
+++ b/flang/include/flang/Runtime/matmul.h
@@ -39,6 +39,8 @@ void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
       const Descriptor &x, const Descriptor &y, const char *sourceFile, \
       int line);
 
+#define MATMUL_FORCE_ALL_TYPES 0
+
 #include "matmul-instances.inc"
 
 } // extern "C"
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8dd1904939f3e..a1cef7437fa2d 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -701,18 +701,19 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc,
   if (name == "pow") {
     assert(funcType.getNumInputs() == 2 && "power operator has two arguments");
     std::string displayName{" ** "};
-    sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
-                                        displayName)
+    sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
+                                          displayName)
             << displayName
-            << numericMlirTypeToFortran(builder, funcType.getInput(1), loc,
-                                        displayName);
+            << mlirTypeToIntrinsicFortran(builder, funcType.getInput(1), loc,
+                                          displayName);
   } else {
     sstream << name.upper() << "(";
     if (funcType.getNumInputs() > 0)
-      sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
-                                          name);
+      sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
+                                            name);
     for (mlir::Type argType : funcType.getInputs().drop_front()) {
-      sstream << ", " << numericMlirTypeToFortran(builder, argType, loc, name);
+      sstream << ", "
+              << mlirTypeToIntrinsicFortran(builder, argType, loc, name);
     }
     sstream << ")";
   }
diff --git a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
index 6d3d85e8df69f..8f08b01fe0097 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
@@ -329,11 +329,64 @@ void fir::runtime::genEoshiftVector(fir::FirOpBuilder &builder,
   builder.create<fir::CallOp>(loc, eoshiftFunc, args);
 }
 
+/// Define ForcedMatmul<ACAT><AKIND><BCAT><BKIND> models.
+struct ForcedMatmulTypeModel {
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto boxRefTy =
+          fir::runtime::getModel<Fortran::runtime::Descriptor &>()(ctx);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto strTy = fir::runtime::getModel<const char *>()(ctx);
+      auto intTy = fir::runtime::getModel<int>()(ctx);
+      auto voidTy = fir::runtime::getModel<void>()(ctx);
+      return mlir::FunctionType::get(
+          ctx, {boxRefTy, boxTy, boxTy, strTy, intTy}, {voidTy});
+    };
+  }
+};
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  struct ForcedMatmul##ACAT##AKIND##BCAT##BKIND                                \
+      : public ForcedMatmulTypeModel {                                         \
+    static constexpr const char *name =                                        \
+        ExpandAndQuoteKey(RTNAME(Matmul##ACAT##AKIND##BCAT##BKIND));           \
+  };
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+
+#include "flang/Runtime/matmul-instances.inc"
+
 /// Generate call to Matmul intrinsic runtime routine.
 void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
                              mlir::Value resultBox, mlir::Value matrixABox,
                              mlir::Value matrixBBox) {
-  auto func = fir::runtime::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
+  mlir::func::FuncOp func;
+  auto boxATy = matrixABox.getType();
+  auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
+  auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
+  auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
+  auto boxBTy = matrixBBox.getType();
+  auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
+  auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
+  auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  if (!func && aCat == TypeCategory::ACAT && aKind == AKIND &&                 \
+      bCat == TypeCategory::BCAT && bKind == BKIND) {                          \
+    func =                                                                     \
+        fir::runtime::getRuntimeFunc<ForcedMatmul##ACAT##AKIND##BCAT##BKIND>(  \
+            loc, builder);                                                     \
+  }
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+#include "flang/Runtime/matmul-instances.inc"
+
+  if (!func) {
+    fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc, "MATMUL");
+  }
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
   auto sourceLine =
@@ -344,13 +397,48 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
   builder.create<fir::CallOp>(loc, func, args);
 }
 
-/// Generate call to MatmulTranspose intrinsic runtime routine.
+/// Define ForcedMatmulTranspose<ACAT><AKIND><BCAT><BKIND> models.
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  struct ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND                       \
+      : public ForcedMatmulTypeModel {                                         \
+    static constexpr const char *name =                                        \
+        ExpandAndQuoteKey(RTNAME(MatmulTranspose##ACAT##AKIND##BCAT##BKIND));  \
+  };
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+
+#include "flang/Runtime/matmul-instances.inc"
+
 void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
                                       mlir::Location loc, mlir::Value resultBox,
                                       mlir::Value matrixABox,
                                       mlir::Value matrixBBox) {
-  auto func =
-      fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
+  mlir::func::FuncOp func;
+  auto boxATy = matrixABox.getType();
+  auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
+  auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
+  auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
+  auto boxBTy = matrixBBox.getType();
+  auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
+  auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
+  auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  if (!func && aCat == TypeCategory::ACAT && aKind == AKIND &&                 \
+      bCat == TypeCategory::BCAT && bKind == BKIND) {                          \
+    func = fir::runtime::getRuntimeFunc<                                       \
+        ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND>(loc, builder);        \
+  }
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+#include "flang/Runtime/matmul-instances.inc"
+
+  if (!func) {
+    fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc,
+                            "MATMUL-TRANSPOSE");
+  }
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
   auto sourceLine =
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index 1c998fa8cf6c1..283472650a1c6 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -343,48 +343,6 @@ inline static RT_API_ATTRS void DoMatmulTranspose(
 
 RT_DIAG_POP
 
-// Maps the dynamic type information from the arguments' descriptors
-// to the right instantiation of DoMatmul() for valid combinations of
-// types.
-template <bool IS_ALLOCATING> struct MatmulTranspose {
-  using ResultDescriptor =
-      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
-  template <TypeCategory XCAT, int XKIND> struct MM1 {
-    template <TypeCategory YCAT, int YKIND> struct MM2 {
-      RT_API_ATTRS void operator()(ResultDescriptor &result,
-          const Descriptor &x, const Descriptor &y,
-          Terminator &terminator) const {
-        if constexpr (constexpr auto resultType{
-                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
-          if constexpr (Fortran::common::IsNumericTypeCategory(
-                            resultType->first) ||
-              resultType->first == TypeCategory::Logical) {
-            return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
-                resultType->second, CppTypeFor<XCAT, XKIND>,
-                CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
-          }
-        }
-        terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
-            static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
-      }
-    };
-    RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
-        const Descriptor &y, Terminator &terminator, TypeCategory yCat,
-        int yKind) const {
-      ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
-    }
-  };
-  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
-      const Descriptor &y, const char *sourceFile, int line) const {
-    Terminator terminator{sourceFile, line};
-    auto xCatKind{x.type().GetCategoryAndKind()};
-    auto yCatKind{y.type().GetCategoryAndKind()};
-    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
-    ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
-        x, y, terminator, yCatKind->first, yCatKind->second);
-  }
-};
-
 template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
     int YKIND>
 struct MatmulTransposeHelper {
@@ -414,15 +372,6 @@ namespace Fortran::runtime {
 extern "C" {
 RT_EXT_API_GROUP_BEGIN
 
-void RTDEF(MatmulTranspose)(Descriptor &result, const Descriptor &x,
-    const Descriptor ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/97547


More information about the flang-commits mailing list