[Mlir-commits] [mlir] [mlir][EmitC] Introduce a `CExpression` trait (PR #84177)

Marius Brehler llvmlistbot at llvm.org
Wed Mar 6 06:53:21 PST 2024


https://github.com/marbre updated https://github.com/llvm/llvm-project/pull/84177

>From 66bef3882abbebeb58472622b220ecbf8c821f4a Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Wed, 6 Mar 2024 14:26:42 +0000
Subject: [PATCH] [mlir][EmitC] Introduce a `CExpression` trait

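This adds a `CExpression` trait and replaces the
`ExpressionOp::isCExpression()` function with a check for this trait.

Besides the ops previously accepted by `isCExpression()` (`add`, `apply`,
`call_opaque`, `cast`, `cmp`, `div`, `mul`, `rem` and `sub`), the trait is
also attached to the bitwise ops, the logical ops and `call`. These ops are
therefore now allowed inside `emitc.expression` and are wrapped into
expressions by `FormExpressionsPass`.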
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h    |  1 +
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 61 +++++++++----------
 .../mlir/Dialect/EmitC/IR/EmitCTraits.h       | 30 +++++++++
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  3 +-
 .../EmitC/Transforms/FormExpressions.cpp      |  2 +-
 .../Dialect/EmitC/Transforms/Transforms.cpp   |  3 +-
 6 files changed, 66 insertions(+), 34 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 3d38744527d599..1f0df3cb336b12 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_EMITC_IR_EMITC_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 6bef395e94eb9d..db0e2d10960d72 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -47,11 +47,14 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
+// Marks ops that are allowed inside the body of an emitc.expression op.
+def CExpression : NativeOpTrait<"emitc::CExpression">;
+
 // Types only used in binary arithmetic operations.
 def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
 def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
 
-def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
+def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
   let summary = "Addition operation";
   let description = [{
     With the `add` operation the arithmetic operator + (addition) can
@@ -74,7 +77,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
   let hasVerifier = 1;
 }
 
-def EmitC_ApplyOp : EmitC_Op<"apply", []> {
+def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
   let summary = "Apply operation";
   let description = [{
     With the `apply` operation the operators & (address of) and * (contents of)
@@ -103,7 +106,7 @@ def EmitC_ApplyOp : EmitC_Op<"apply", []> {
   let hasVerifier = 1;
 }
 
-def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> {
+def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", [CExpression]> {
   let summary = "Bitwise and operation";
   let description = [{
     With the `bitwise_and` operation the bitwise operator & (and) can
@@ -121,7 +124,8 @@ def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> {
   }];
 }
 
-def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> {
+def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift",
+    [CExpression]> {
   let summary = "Bitwise left shift operation";
   let description = [{
     With the `bitwise_left_shift` operation the bitwise operator <<
@@ -139,7 +143,7 @@ def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> {
   }];
 }
 
-def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> {
+def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", [CExpression]> {
   let summary = "Bitwise not operation";
   let description = [{
     With the `bitwise_not` operation the bitwise operator ~ (not) can
@@ -157,7 +161,7 @@ def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> {
   }];
 }
 
-def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> {
+def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", [CExpression]> {
   let summary = "Bitwise or operation";
   let description = [{
     With the `bitwise_or` operation the bitwise operator | (or)
@@ -175,7 +179,8 @@ def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> {
   }];
 }
 
-def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> {
+def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift",
+    [CExpression]> {
   let summary = "Bitwise right shift operation";
   let description = [{
     With the `bitwise_right_shift` operation the bitwise operator >>
@@ -193,7 +198,7 @@ def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> {
   }];
 }
 
-def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> {
+def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", [CExpression]> {
   let summary = "Bitwise xor operation";
   let description = [{
     With the `bitwise_xor` operation the bitwise operator ^ (xor)
@@ -211,7 +216,7 @@ def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> {
   }];
 }
 
-def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> {
+def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
   let summary = "Opaque call operation";
   let description = [{
     The `call_opaque` operation represents a C++ function call. The callee
@@ -257,10 +262,10 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> {
   let hasVerifier = 1;
 }
 
-def EmitC_CastOp : EmitC_Op<"cast", [
-    DeclareOpInterfaceMethods<CastOpInterface>,
-    SameOperandsAndResultShape
-  ]> {
+def EmitC_CastOp : EmitC_Op<"cast",
+    [CExpression,
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     SameOperandsAndResultShape]> {
   let summary = "Cast operation";
   let description = [{
     The `cast` operation performs an explicit type conversion and is emitted
@@ -284,7 +289,7 @@ def EmitC_CastOp : EmitC_Op<"cast", [
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 }
 
-def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
+def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
   let summary = "Comparison operation";
   let description = [{
     With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=> 
@@ -355,7 +360,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
   let hasVerifier = 1;
 }
 
-def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
+def EmitC_DivOp : EmitC_BinaryOp<"div", [CExpression]> {
   let summary = "Division operation";
   let description = [{
     With the `div` operation the arithmetic operator / (division) can
@@ -409,9 +414,8 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
     int32_t v7 = foo(v1 + v2) * (v3 + v4);
     ```
 
-    The operations allowed within expression body are `emitc.add`,
-    `emitc.apply`, `emitc.call_opaque`, `emitc.cast`, `emitc.cmp`, `emitc.div`,
-    `emitc.mul`, `emitc.rem`, and `emitc.sub`.
+    The operations allowed within expression body are EmitC operations with the
+    CExpression trait.
 
     When specified, the optional `do_not_inline` indicates that the expression is
     to be emitted as seen above, i.e. as the rhs of an EmitC SSA value
@@ -427,14 +431,9 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
   let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";
 
   let extraClassDeclaration = [{
-    static bool isCExpression(Operation &op) {
-      return isa<emitc::AddOp, emitc::ApplyOp, emitc::CallOpaqueOp,
-                 emitc::CastOp, emitc::CmpOp, emitc::DivOp, emitc::MulOp,
-                 emitc::RemOp, emitc::SubOp>(op);
-    }
     bool hasSideEffects() {
       auto predicate = [](Operation &op) {
-        assert(isCExpression(op) && "Expected a C expression");
+        assert(op.hasTrait<OpTrait::emitc::CExpression>() && "Expected a C expression");
         // Conservatively assume calls to read and write memory.
-        if (isa<emitc::CallOpaqueOp>(op))
+        if (isa<emitc::CallOp, emitc::CallOpaqueOp>(op))
           return true;
@@ -518,7 +517,7 @@ def EmitC_ForOp : EmitC_Op<"for",
 }
 
 def EmitC_CallOp : EmitC_Op<"call",
-    [CallOpInterface,
+    [CallOpInterface, CExpression,
      DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "call operation";
   let description = [{
@@ -774,7 +773,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
   let assemblyFormat = "$value attr-dict `:` type($result)";
 }
 
-def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> {
+def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", [CExpression]> {
   let summary = "Logical and operation";
   let description = [{
     With the `logical_and` operation the logical operator && (and) can
@@ -795,7 +794,7 @@ def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> {
+def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", [CExpression]> {
   let summary = "Logical not operation";
   let description = [{
     With the `logical_not` operation the logical operator ! (negation) can
@@ -816,7 +815,7 @@ def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> {
+def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
   let summary = "Logical or operation";
   let description = [{
     With the `logical_or` operation the logical operator || (inclusive or)
@@ -837,7 +836,7 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
+def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
   let summary = "Multiplication operation";
   let description = [{
     With the `mul` operation the arithmetic operator * (multiplication) can
@@ -861,7 +860,7 @@ def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
   let results = (outs FloatIntegerIndexOrOpaqueType);
 }
 
-def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
+def EmitC_RemOp : EmitC_BinaryOp<"rem", [CExpression]> {
   let summary = "Remainder operation";
   let description = [{
     With the `rem` operation the arithmetic operator % (remainder) can
@@ -883,7 +882,7 @@ def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
   let results = (outs IntegerIndexOrOpaqueType);
 }
 
-def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
+def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
   let summary = "Subtraction operation";
   let description = [{
     With the `sub` operation the arithmetic operator - (subtraction) can
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h
new file mode 100644
index 00000000000000..c1602dfce4b484
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h
@@ -0,0 +1,34 @@
+//===- EmitCTraits.h - EmitC trait definitions ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares C++ classes for some of the traits used in the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
+#define MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace emitc {
+
+/// Marks an operation as a C expression: only operations carrying this trait
+/// are accepted inside the region of an `emitc.expression` op. It defines no
+/// members and only serves as a marker, queried via
+/// `Operation::hasTrait<OpTrait::emitc::CExpression>()`.
+template <typename ConcreteType>
+class CExpression : public TraitBase<ConcreteType, CExpression> {};
+
+} // namespace emitc
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4df8149b94c95f..07ee1d394287b9 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -244,7 +245,7 @@ LogicalResult ExpressionOp::verify() {
     return emitOpError("requires yielded type to match return type");
 
   for (Operation &op : region.front().without_terminator()) {
-    if (!isCExpression(op))
+    if (!op.hasTrait<OpTrait::emitc::CExpression>())
       return emitOpError("contains an unsupported operation");
     if (op.getNumResults() != 1)
       return emitOpError("requires exactly one result for each operation");
diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
index 21212155ffb22f..5b03f81b305fd5 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
@@ -36,7 +36,7 @@ struct FormExpressionsPass
     // Wrap each C operator op with an expression op.
     OpBuilder builder(context);
     auto matchFun = [&](Operation *op) {
-      if (emitc::ExpressionOp::isCExpression(*op))
+      if (op->hasTrait<OpTrait::emitc::CExpression>())
         createExpression(op, builder);
     };
     rootOp->walk(matchFun);
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 88b691b50f325d..87350ecdceaaac 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -16,7 +16,8 @@ namespace mlir {
 namespace emitc {
 
 ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
-  assert(ExpressionOp::isCExpression(*op) && "Expected a C expression");
+  assert(op->hasTrait<OpTrait::emitc::CExpression>() &&
+         "Expected a C expression");
 
   // Create an expression yielding the value returned by op.
   assert(op->getNumResults() == 1 && "Expected exactly one result");



More information about the Mlir-commits mailing list