[flang-commits] [flang] [flang][Evaluate] OperationCode cleanup, fix for Constant<T> (PR #151566)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Thu Jul 31 10:41:55 PDT 2025


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/151566

Make the OperationCode overloads take the derived operation instead of the Operation base class instance. This makes them usable from visitors of "Expr<T>.u".

Also, fix a small bug: OperationCode(Constant<T>) should be "Constant".

>From 9d032d4a4aceeabd7026c6ffe94656cf9eb91813 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 31 Jul 2025 12:29:25 -0500
Subject: [PATCH] [flang][Evaluate] OperationCode cleanup, fix for Constant<T>

Make the OperationCode overloads take the derived operation instead of
the Operation base class instance. This makes them usable from visitors
of "Expr<T>.u".

Also, fix a small bug: OperationCode(Constant<T>) should be "Constant".
---
 flang/include/flang/Evaluate/tools.h     | 54 +++++++++---------------
 flang/lib/Evaluate/tools.cpp             |  6 +--
 flang/test/Semantics/OpenMP/atomic04.f90 |  2 +-
 flang/test/Semantics/OpenMP/atomic05.f90 |  2 +-
 4 files changed, 25 insertions(+), 39 deletions(-)

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index cef57f1851bcc..e2c9878ab19a9 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1402,10 +1402,8 @@ using OperatorSet = common::EnumSet<Operator, 32>;
 
 std::string ToString(Operator op);
 
-template <typename... Ts, int Kind>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
-  switch (op.derived().logicalOperator) {
+template <int Kind> Operator OperationCode(const LogicalOperation<Kind> &op) {
+  switch (op.logicalOperator) {
   case common::LogicalOperator::And:
     return Operator::And;
   case common::LogicalOperator::Or:
@@ -1420,10 +1418,8 @@ Operator OperationCode(
   return Operator::Unknown;
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
-  switch (op.derived().opr) {
+template <typename T> Operator OperationCode(const Relational<T> &op) {
+  switch (op.opr) {
   case common::RelationalOperator::LT:
     return Operator::Lt;
   case common::RelationalOperator::LE:
@@ -1440,44 +1436,32 @@ Operator OperationCode(
   return Operator::Unknown;
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
+template <typename T> Operator OperationCode(const Add<T> &op) {
   return Operator::Add;
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
+template <typename T> Operator OperationCode(const Subtract<T> &op) {
   return Operator::Sub;
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
+template <typename T> Operator OperationCode(const Multiply<T> &op) {
   return Operator::Mul;
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
+template <typename T> Operator OperationCode(const Divide<T> &op) {
   return Operator::Div;
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
+template <typename T> Operator OperationCode(const Power<T> &op) {
   return Operator::Pow;
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
+template <typename T> Operator OperationCode(const RealToIntPower<T> &op) {
   return Operator::Pow;
 }
 
-template <typename T, common::TypeCategory C, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
+template <typename T, common::TypeCategory C>
+Operator OperationCode(const Convert<T, C> &op) {
   if constexpr (C == T::category) {
     return Operator::Resize;
   } else {
@@ -1485,25 +1469,27 @@ Operator OperationCode(
   }
 }
 
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Extremum<T>, Ts...> &op) {
-  if (op.derived().ordering == evaluate::Ordering::Greater) {
+template <typename T> Operator OperationCode(const Extremum<T> &op) {
+  if (op.ordering == Ordering::Greater) {
     return Operator::Max;
   } else {
     return Operator::Min;
   }
 }
 
-template <typename T> Operator OperationCode(const evaluate::Constant<T> &x) {
+template <typename T> Operator OperationCode(const Constant<T> &x) {
   return Operator::Constant;
 }
 
+template <typename T> Operator OperationCode(const Designator<T> &x) {
+  return Operator::Identity;
+}
+
 template <typename T> Operator OperationCode(const T &) {
   return Operator::Unknown;
 }
 
-Operator OperationCode(const evaluate::ProcedureDesignator &proc);
+Operator OperationCode(const ProcedureDesignator &proc);
 
 } // namespace operation
 
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 171dd91fa9fd1..90be131651697 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1693,17 +1693,17 @@ struct ArgumentExtractor
       // to int(kind=4) for example.
       return (*this)(x.template operand<0>());
     } else {
-      return std::make_pair(operation::OperationCode(x),
+      return std::make_pair(operation::OperationCode(x.derived()),
           OperationArgs(x, std::index_sequence_for<Os...>{}));
     }
   }
 
   template <typename T> Result operator()(const Designator<T> &x) const {
-    return {operation::Operator::Identity, {AsSomeExpr(x)}};
+    return {operation::OperationCode(x), {AsSomeExpr(x)}};
   }
 
   template <typename T> Result operator()(const Constant<T> &x) const {
-    return {operation::Operator::Identity, {AsSomeExpr(x)}};
+    return {operation::OperationCode(x), {AsSomeExpr(x)}};
   }
 
   template <typename... Rs>
diff --git a/flang/test/Semantics/OpenMP/atomic04.f90 b/flang/test/Semantics/OpenMP/atomic04.f90
index fb87ca5186612..8f8af31245404 100644
--- a/flang/test/Semantics/OpenMP/atomic04.f90
+++ b/flang/test/Semantics/OpenMP/atomic04.f90
@@ -180,7 +180,7 @@ subroutine more_invalid_atomic_update_stmts()
         x = x
 
     !$omp atomic update
-    !ERROR: The atomic variable x should appear as an argument in the update operation
+    !ERROR: This is not a valid ATOMIC UPDATE operation
         x = 1    
 
     !$omp atomic update
diff --git a/flang/test/Semantics/OpenMP/atomic05.f90 b/flang/test/Semantics/OpenMP/atomic05.f90
index 77ffc6e57f1a3..e0103be4cae4a 100644
--- a/flang/test/Semantics/OpenMP/atomic05.f90
+++ b/flang/test/Semantics/OpenMP/atomic05.f90
@@ -19,7 +19,7 @@ program OmpAtomic
         x = 2 * 4
     !ERROR: At most one clause from the 'memory-order' group is allowed on ATOMIC construct
     !$omp atomic update release, seq_cst
-    !ERROR: The atomic variable x should appear as an argument in the update operation
+    !ERROR: This is not a valid ATOMIC UPDATE operation
         x = 10
     !ERROR: At most one clause from the 'memory-order' group is allowed on ATOMIC construct
     !$omp atomic capture release, seq_cst



More information about the flang-commits mailing list