[Mlir-commits] [mlir] [mlir][emitc] make CExpression trait into interface (PR #142771)

Kirill Chibisov llvmlistbot at llvm.org
Mon Jun 16 03:48:26 PDT 2025


https://github.com/kchibisov updated https://github.com/llvm/llvm-project/pull/142771

>From b34f0d3576bd6a6ef48e55a504131c45f8be9817 Mon Sep 17 00:00:00 2001
From: Kirill Chibisov <contact at kchibisov.com>
Date: Wed, 4 Jun 2025 20:12:44 +0900
Subject: [PATCH] [mlir][emitc] make CExpression trait into interface

By defining `CExpressionInterface`, we move the side-effect detection
logic from `emitc.expression` into the individual operations
implementing the interface, allowing each operation to gradually tune
its own side-effect detection logic.

It also allows checking for side effects of each operation individually.
---
 .../mlir/Dialect/EmitC/IR/CMakeLists.txt      |   6 +
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h    |   2 +-
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 107 +++++++++++-------
 .../mlir/Dialect/EmitC/IR/EmitCInterfaces.h   |  31 +++++
 .../mlir/Dialect/EmitC/IR/EmitCInterfaces.td  |  48 ++++++++
 .../mlir/Dialect/EmitC/IR/EmitCTraits.h       |  30 -----
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |   6 +-
 .../EmitC/Transforms/FormExpressions.cpp      |   2 +-
 .../Dialect/EmitC/Transforms/Transforms.cpp   |   3 +-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |   6 +-
 10 files changed, 159 insertions(+), 82 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.h
 create mode 100644 mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td
 delete mode 100644 mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
index 610170f5944eb..299cee76cb1b4 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
@@ -1,6 +1,12 @@
 add_mlir_dialect(EmitC emitc)
 add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc -dialect emitc)
 
+set(LLVM_TARGET_DEFINITIONS EmitCInterfaces.td)
+mlir_tablegen(EmitCInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(EmitCInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIREmitCInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIREmitCInterfacesIncGen)
+
 set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td)
 mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls)
 mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 57029c64ffd00..1984ed8a7f068 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -14,7 +14,7 @@
 #define MLIR_DIALECT_EMITC_IR_EMITC_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
-#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
+#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d4aea52a0d485..ebe35372c4b9b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_EMITC_IR_EMITC
 
 include "mlir/Dialect/EmitC/IR/EmitCAttributes.td"
+include "mlir/Dialect/EmitC/IR/EmitCInterfaces.td"
 include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
 
 include "mlir/Interfaces/CallInterfaces.td"
@@ -35,22 +36,31 @@ class EmitC_Op<string mnemonic, list<Trait> traits = []>
 
 // Base class for unary operations.
 class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> :
-    EmitC_Op<mnemonic, traits> {
+    EmitC_Op<mnemonic, !listconcat(traits, [CExpressionInterface])> {
   let arguments = (ins EmitCType);
   let results = (outs EmitCType);
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+  // All unary ops built on this class are treated as side-effect free.
+  let extraClassDeclaration = [{
+    bool hasSideEffects() {
+      return false;
+    }
+  }];
 }
 
 // Base class for binary operations.
 class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
-    EmitC_Op<mnemonic, traits> {
+    EmitC_Op<mnemonic, !listconcat(traits, [CExpressionInterface])> {
   let arguments = (ins EmitCType:$lhs, EmitCType:$rhs);
   let results = (outs EmitCType);
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
-}
 
-// EmitC OpTrait
-def CExpression : NativeOpTrait<"emitc::CExpression">;
+  let extraClassDeclaration = [{
+    bool hasSideEffects() {
+      return false;
+    }
+  }];
+}
 
 // Types only used in binary arithmetic operations.
 def IntegerIndexOrOpaqueType : Type<CPred<"emitc::isIntegerIndexOrOpaqueType($_self)">,
@@ -103,7 +113,7 @@ def EmitC_FileOp
   let skipDefaultBuilders = 1;
 }
 
-def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
+def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
   let summary = "Addition operation";
   let description = [{
     With the `emitc.add` operation the arithmetic operator + (addition) can
@@ -126,7 +136,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
   let hasVerifier = 1;
 }
 
-def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
+def EmitC_ApplyOp : EmitC_Op<"apply", [CExpressionInterface]> {
   let summary = "Apply operation";
   let description = [{
     With the `emitc.apply` operation the operators & (address of) and * (contents of)
@@ -152,10 +162,17 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
   let assemblyFormat = [{
     $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
   }];
+  // `*` (dereference) reads memory; `&` (address-of) has no side effect.
+  let extraClassDeclaration = [{
+    bool hasSideEffects() {
+      return getApplicableOperator() == "*";
+    }
+  }];
+
   let hasVerifier = 1;
 }
 
-def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", [CExpression]> {
+def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> {
   let summary = "Bitwise and operation";
   let description = [{
     With the `emitc.bitwise_and` operation the bitwise operator & (and) can
@@ -173,8 +190,7 @@ def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", [CExpression]> {
   }];
 }
 
-def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift",
-    [CExpression]> {
+def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> {
   let summary = "Bitwise left shift operation";
   let description = [{
     With the `emitc.bitwise_left_shift` operation the bitwise operator <<
@@ -192,7 +208,7 @@ def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift",
   }];
 }
 
-def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", [CExpression]> {
+def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> {
   let summary = "Bitwise not operation";
   let description = [{
     With the `emitc.bitwise_not` operation the bitwise operator ~ (not) can
@@ -210,7 +226,7 @@ def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", [CExpression]> {
   }];
 }
 
-def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", [CExpression]> {
+def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> {
   let summary = "Bitwise or operation";
   let description = [{
     With the `emitc.bitwise_or` operation the bitwise operator | (or)
@@ -228,8 +244,7 @@ def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", [CExpression]> {
   }];
 }
 
-def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift",
-    [CExpression]> {
+def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> {
   let summary = "Bitwise right shift operation";
   let description = [{
     With the `emitc.bitwise_right_shift` operation the bitwise operator >>
@@ -247,7 +262,7 @@ def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift",
   }];
 }
 
-def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", [CExpression]> {
+def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> {
   let summary = "Bitwise xor operation";
   let description = [{
     With the `emitc.bitwise_xor` operation the bitwise operator ^ (xor)
@@ -265,7 +280,7 @@ def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", [CExpression]> {
   }];
 }
 
-def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
+def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpressionInterface]> {
   let summary = "Opaque call operation";
   let description = [{
     The `emitc.call_opaque` operation represents a C++ function call. The callee
@@ -312,7 +327,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
 }
 
 def EmitC_CastOp : EmitC_Op<"cast",
-    [CExpression,
+    [CExpressionInterface,
      DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "Cast operation";
   let description = [{
@@ -335,9 +350,15 @@ def EmitC_CastOp : EmitC_Op<"cast",
   let arguments = (ins EmitCType:$source);
   let results = (outs EmitCType:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  // Casts are treated as side-effect free.
+  let extraClassDeclaration = [{
+    bool hasSideEffects() {
+      return false;
+    }
+  }];
 }
 
-def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
+def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
   let summary = "Comparison operation";
   let description = [{
     With the `emitc.cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=> 
@@ -407,7 +428,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
   let hasVerifier = 1;
 }
 
-def EmitC_DivOp : EmitC_BinaryOp<"div", [CExpression]> {
+def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
   let summary = "Division operation";
   let description = [{
     With the `emitc.div` operation the arithmetic operator / (division) can
@@ -462,7 +483,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
     ```
 
     The operations allowed within expression body are EmitC operations with the
-    CExpression trait.
+    CExpressionInterface.
 
     When specified, the optional `do_not_inline` indicates that the expression is
     to be emitted as seen above, i.e. as the rhs of an EmitC SSA value
@@ -480,18 +501,8 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
   let extraClassDeclaration = [{
     bool hasSideEffects() {
       auto predicate = [](Operation &op) {
-        assert(op.hasTrait<OpTrait::emitc::CExpression>() && "Expected a C expression");
-        // Conservatively assume calls to read and write memory.
-        if (isa<emitc::CallOpaqueOp>(op))
-          return true;
-        // De-referencing reads modifiable memory, address-taking has no
-        // side-effect.
-        auto applyOp = dyn_cast<emitc::ApplyOp>(op);
-        if (applyOp)
-          return applyOp.getApplicableOperator() == "*";
-        // Any load operation is assumed to read from memory and thus perform
-        // a side effect.
-        return isa<emitc::LoadOp>(op);
+        assert(isa<emitc::CExpressionInterface>(op) && "Expected a C expression");
+        return cast<emitc::CExpressionInterface>(op).hasSideEffects();
       };
       return llvm::any_of(getRegion().front().without_terminator(), predicate);
     };
@@ -579,7 +590,7 @@ def EmitC_ForOp : EmitC_Op<"for",
 }
 
 def EmitC_CallOp : EmitC_Op<"call",
-    [CallOpInterface, CExpression,
+    [CallOpInterface, CExpressionInterface,
      DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "Call operation";
   let description = [{
@@ -649,6 +660,10 @@ def EmitC_CallOp : EmitC_Op<"call",
     void setCalleeFromCallable(CallInterfaceCallable callee) {
       (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
     }
+    // The callee may run arbitrary code: conservatively report side effects.
+    bool hasSideEffects() {
+      return true;
+    }
   }];
 
   let assemblyFormat = [{
@@ -861,7 +876,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
   let assemblyFormat = "$value attr-dict `:` type($result)";
 }
 
-def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", [CExpression]> {
+def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> {
   let summary = "Logical and operation";
   let description = [{
     With the `emitc.logical_and` operation the logical operator && (and) can
@@ -882,7 +897,7 @@ def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", [CExpression]> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", [CExpression]> {
+def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> {
   let summary = "Logical not operation";
   let description = [{
     With the `emitc.logical_not` operation the logical operator ! (negation) can
@@ -903,7 +918,7 @@ def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", [CExpression]> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
+def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> {
   let summary = "Logical or operation";
   let description = [{
     With the `emitc.logical_or` operation the logical operator || (inclusive or)
@@ -924,7 +939,7 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_LoadOp : EmitC_Op<"load", [CExpression,
+def EmitC_LoadOp : EmitC_Op<"load", [CExpressionInterface,
  TypesMatchWith<"result type matches value type of 'operand'",
                   "operand", "result",
                   "::llvm::cast<LValueType>($_self).getValueType()">
@@ -953,7 +968,7 @@ def EmitC_LoadOp : EmitC_Op<"load", [CExpression,
   let assemblyFormat = "$operand attr-dict `:` type($operand)"; 
 }
 
-def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
+def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
   let summary = "Multiplication operation";
   let description = [{
     With the `emitc.mul` operation the arithmetic operator * (multiplication) can
@@ -977,7 +992,7 @@ def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
   let results = (outs FloatIntegerIndexOrOpaqueType);
 }
 
-def EmitC_RemOp : EmitC_BinaryOp<"rem", [CExpression]> {
+def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
   let summary = "Remainder operation";
   let description = [{
     With the `emitc.rem` operation the arithmetic operator % (remainder) can
@@ -999,7 +1014,7 @@ def EmitC_RemOp : EmitC_BinaryOp<"rem", [CExpression]> {
   let results = (outs IntegerIndexOrOpaqueType);
 }
 
-def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
+def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
   let summary = "Subtraction operation";
   let description = [{
     With the `emitc.sub` operation the arithmetic operator - (subtraction) can
@@ -1069,7 +1084,7 @@ def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {
 }
 
 def EmitC_ConditionalOp : EmitC_Op<"conditional",
-    [AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
+    [AllTypesMatch<["true_value", "false_value", "result"]>, CExpressionInterface]> {
   let summary = "Conditional (ternary) operation";
   let description = [{
     With the `emitc.conditional` operation the ternary conditional operator can
@@ -1096,9 +1111,15 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional",
   let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value);
   let results = (outs EmitCType:$result);
   let assemblyFormat = "operands attr-dict `:` type($result)";
+  // The conditional operator itself is treated as side-effect free.
+  let extraClassDeclaration = [{
+    bool hasSideEffects() {
+      return false;
+    }
+  }];
 }
 
-def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
+def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", []> {
   let summary = "Unary minus operation";
   let description = [{
     With the `emitc.unary_minus` operation the unary operator - (minus) can be
@@ -1116,7 +1137,7 @@ def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
   }];
 }
 
-def EmitC_UnaryPlusOp : EmitC_UnaryOp<"unary_plus", [CExpression]> {
+def EmitC_UnaryPlusOp : EmitC_UnaryOp<"unary_plus", []> {
   let summary = "Unary plus operation";
   let description = [{
     With the `emitc.unary_plus` operation the unary operator + (plus) can be
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.h
new file mode 100644
index 0000000000000..51efe76aceb5c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.h
@@ -0,0 +1,31 @@
+//===- EmitCInterfaces.h - EmitC interface definitions ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the C++ classes for the op interfaces used in the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_IR_EMITCINTERFACES_H
+#define MLIR_DIALECT_EMITC_IR_EMITCINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace emitc {
+// Currently empty; the interfaces come from the .h.inc included below.
+} // namespace emitc
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// EmitC Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_EMITC_IR_EMITCINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td
new file mode 100644
index 0000000000000..777784e56202a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td
@@ -0,0 +1,48 @@
+//===- EmitCInterfaces.td - EmitC Interfaces ---------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces used by EmitC.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_IR_EMITCINTERFACES
+#define MLIR_DIALECT_EMITC_IR_EMITCINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def CExpressionInterface : OpInterface<"CExpressionInterface"> {
+  let description = [{
+    Interface for operations that may appear inside an `emitc.expression`.
+  }];
+
+  let cppNamespace = "::mlir::emitc";
+  let methods = [
+    InterfaceMethod<[{
+      Returns whether the operation has side effects that may affect the
+      evaluation of the enclosing expression.
+
+      By default, operations are conservatively assumed to have side effects.
+
+      ```c++
+      class ConcreteOp ... {
+      public:
+        bool hasSideEffects() {
+          // Override the conservative default for side-effect-free ops.
+          return false;
+        }
+      };
+      ```
+    }],
+      "bool", "hasSideEffects", (ins), /*methodBody=*/[{}],
+       /*defaultImplementation=*/[{
+        return true;
+    }]>,
+  ];
+}
+
+#endif // MLIR_DIALECT_EMITC_IR_EMITCINTERFACES
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h
deleted file mode 100644
index c1602dfce4b48..0000000000000
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h
+++ /dev/null
@@ -1,30 +0,0 @@
-//===- EmitCTraits.h - EmitC trait definitions ------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file declares C++ classes for some of the traits used in the EmitC
-// dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
-#define MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
-
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace OpTrait {
-namespace emitc {
-
-template <typename ConcreteType>
-class CExpression : public TraitBase<ConcreteType, CExpression> {};
-
-} // namespace emitc
-} // namespace OpTrait
-} // namespace mlir
-
-#endif // MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 1709654b90138..b5f86406c8891 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
+#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -412,7 +412,7 @@ LogicalResult ExpressionOp::verify() {
     return emitOpError("requires yielded type to match return type");
 
   for (Operation &op : region.front().without_terminator()) {
-    if (!op.hasTrait<OpTrait::emitc::CExpression>())
+    if (!isa<emitc::CExpressionInterface>(op))
       return emitOpError("contains an unsupported operation");
     if (op.getNumResults() != 1)
       return emitOpError("requires exactly one result for each operation");
@@ -1398,5 +1398,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
index 224d68ab8b4a6..2f3e2618f4d74 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
@@ -36,7 +36,7 @@ struct FormExpressionsPass
     // Wrap each C operator op with an expression op.
     OpBuilder builder(context);
     auto matchFun = [&](Operation *op) {
-      if (op->hasTrait<OpTrait::emitc::CExpression>() &&
+      if (isa<emitc::CExpressionInterface>(*op) &&
           !op->getParentOfType<emitc::ExpressionOp>() &&
           op->getNumResults() == 1)
         createExpression(op, builder);
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 87350ecdceaaa..a578a86b499a6 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -16,8 +16,7 @@ namespace mlir {
 namespace emitc {
 
 ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
-  assert(op->hasTrait<OpTrait::emitc::CExpression>() &&
-         "Expected a C expression");
+  assert(isa<emitc::CExpressionInterface>(op) && "Expected a C expression");
 
   // Create an expression yielding the value returned by op.
   assert(op->getNumResults() == 1 && "Expected exactly one result");
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 5abc112ab8c7a..067a0470b14e4 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -329,9 +329,9 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
   if (hasDeferredEmission(user))
     return false;
 
-  // Do not inline expressions used by ops with the CExpression trait. If this
-  // was intended, the user could have been merged into the expression op.
-  return !user->hasTrait<OpTrait::emitc::CExpression>();
+  // Do not inline expressions whose user implements CExpressionInterface. If
+  // this was intended, the user could have been merged into the expression op.
+  return !isa<emitc::CExpressionInterface>(*user);
 }
 
 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,



More information about the Mlir-commits mailing list