[Mlir-commits] [mlir] adea7e7 - [mlir][emitc] Add comparison operation

Marius Brehler llvmlistbot at llvm.org
Tue Aug 29 09:51:08 PDT 2023


Author: Simon Camphausen
Date: 2023-08-29T16:50:32Z
New Revision: adea7e7032adc9d3ba2f9e18aea3d1cfbd6514b2

URL: https://github.com/llvm/llvm-project/commit/adea7e7032adc9d3ba2f9e18aea3d1cfbd6514b2
DIFF: https://github.com/llvm/llvm-project/commit/adea7e7032adc9d3ba2f9e18aea3d1cfbd6514b2.diff

LOG: [mlir][emitc] Add comparison operation

This adds a comparison operation, `emitc.cmp`, to EmitC. It supports the C++ comparison operators ==, !=, <, <=, >, >=, and <=>.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D158180

Added: 
    mlir/test/Target/Cpp/comparison_operators.mlir

Modified: 
    mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
    mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
    mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
    mlir/lib/Dialect/EmitC/IR/EmitC.cpp
    mlir/lib/Target/Cpp/TranslateToCpp.cpp
    mlir/test/Dialect/EmitC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
index 09a9f7a2ec1c59..ac8c651cdced89 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
@@ -2,6 +2,8 @@ add_mlir_dialect(EmitC emitc)
 add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc)
 
 set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td)
+mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls)
+mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(EmitCAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(EmitCAttributes.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(MLIREmitCAttributesIncGen)

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 0acaa85139508f..b3c1170eefdab9 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -21,6 +21,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
+#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index a10b64da436140..9e0880089c9f87 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -27,8 +27,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 class EmitC_Op<string mnemonic, list<Trait> traits = []>
     : Op<EmitC_Dialect, mnemonic, traits>;
 
-// Base class for binary arithmetic operations.
-class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
+// Base class for binary operations.
+class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
     EmitC_Op<mnemonic, traits> {
   let arguments = (ins AnyType:$lhs, AnyType:$rhs);
   let results = (outs AnyType);
@@ -39,7 +39,7 @@ class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
 def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
 def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
 
-def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
+def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
   let summary = "Addition operation";
   let description = [{
     With the `add` operation the arithmetic operator + (addition) can
@@ -150,6 +150,37 @@ def EmitC_CastOp : EmitC_Op<"cast", [
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 }
 
+def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
+  let summary = "Comparison operation";
+  let description = [{
+    With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=> 
+    can be applied.
+
+    Example:
+    ```mlir
+    // Custom form of the cmp operation.
+    %0 = emitc.cmp eq, %arg0, %arg1 : (i32, i32) -> i1
+    %1 = emitc.cmp lt, %arg2, %arg3 : 
+        (
+          !emitc.opaque<"std::valarray<float>">,
+          !emitc.opaque<"std::valarray<float>">
+        ) -> !emitc.opaque<"std::valarray<bool>">
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    bool v5 = v1 == v2;
+    std::valarray<bool> v6 = v3 < v4;
+    ```
+  }];
+
+  let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
+                       AnyType:$lhs,
+                       AnyType:$rhs);
+  let results = (outs AnyType);
+
+  let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
+}
+
 def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
   let summary = "Constant operation";
   let description = [{
@@ -180,7 +211,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
   let hasVerifier = 1;
 }
 
-def EmitC_DivOp : EmitC_BinaryArithOp<"div", []> {
+def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
   let summary = "Division operation";
   let description = [{
     With the `div` operation the arithmetic operator / (division) can
@@ -248,7 +279,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
   let assemblyFormat = "$value attr-dict `:` type($result)";
 }
 
-def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
+def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
   let summary = "Multiplication operation";
   let description = [{
     With the `mul` operation the arithmetic operator * (multiplication) can
@@ -272,7 +303,7 @@ def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
   let results = (outs FloatIntegerIndexOrOpaqueType);
 }
 
-def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
+def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
   let summary = "Remainder operation";
   let description = [{
     With the `rem` operation the arithmetic operator % (remainder) can
@@ -294,7 +325,7 @@ def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
   let results = (outs IntegerIndexOrOpaqueType);
 }
 
-def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
+def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
   let summary = "Subtraction operation";
   let description = [{
     With the `sub` operation the arithmetic operator - (subtraction) can

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
index d69b1c20eaee5e..ae843e49c6c5ba 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
@@ -15,6 +15,7 @@
 
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/EmitC/IR/EmitCBase.td"
 
 //===----------------------------------------------------------------------===//
@@ -26,6 +27,20 @@ class EmitC_Attr<string name, string attrMnemonic, list<Trait> traits = []>
   let mnemonic = attrMnemonic;
 }
 
+def EmitC_CmpPredicateAttr : I64EnumAttr<
+    "CmpPredicate", "",
+    [
+      I64EnumAttrCase<"eq", 0>,
+      I64EnumAttrCase<"ne", 1>,
+      I64EnumAttrCase<"lt", 2>,
+      I64EnumAttrCase<"le", 3>,
+      I64EnumAttrCase<"gt", 4>,
+      I64EnumAttrCase<"ge", 5>,
+      I64EnumAttrCase<"three_way", 6>,
+    ]> {
+  let cppNamespace = "::mlir::emitc";
+}
+
 def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
   let summary = "An opaque attribute";
 

diff  --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ceb1e94ce8177d..0aac33e22cbc40 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -257,6 +257,12 @@ LogicalResult emitc::VariableOp::verify() {
 #define GET_OP_CLASSES
 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// EmitC Enums
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // EmitC Attributes
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 396895d3b559da..c36e202f1b17e2 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -246,15 +246,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return printConstantOp(emitter, operation, value);
 }
 
-static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
-                                               Operation *operation,
-                                               StringRef binaryArithOperator) {
+static LogicalResult printBinaryOperation(CppEmitter &emitter,
+                                          Operation *operation,
+                                          StringRef binaryOperator) {
   raw_ostream &os = emitter.ostream();
 
   if (failed(emitter.emitAssignPrefix(*operation)))
     return failure();
   os << emitter.getOrCreateName(operation->getOperand(0));
-  os << " " << binaryArithOperator;
+  os << " " << binaryOperator;
   os << " " << emitter.getOrCreateName(operation->getOperand(1));
 
   return success();
@@ -263,31 +263,65 @@ static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
   Operation *operation = addOp.getOperation();
 
-  return printBinaryArithOperation(emitter, operation, "+");
+  return printBinaryOperation(emitter, operation, "+");
 }
 
 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
   Operation *operation = divOp.getOperation();
 
-  return printBinaryArithOperation(emitter, operation, "/");
+  return printBinaryOperation(emitter, operation, "/");
 }
 
 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
   Operation *operation = mulOp.getOperation();
 
-  return printBinaryArithOperation(emitter, operation, "*");
+  return printBinaryOperation(emitter, operation, "*");
 }
 
 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
   Operation *operation = remOp.getOperation();
 
-  return printBinaryArithOperation(emitter, operation, "%");
+  return printBinaryOperation(emitter, operation, "%");
 }
 
 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
   Operation *operation = subOp.getOperation();
 
-  return printBinaryArithOperation(emitter, operation, "-");
+  return printBinaryOperation(emitter, operation, "-");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
+  Operation *operation = cmpOp.getOperation();
+
+  StringRef binaryOperator;
+
+  switch (cmpOp.getPredicate()) {
+  case emitc::CmpPredicate::eq:
+    binaryOperator = "==";
+    break;
+  case emitc::CmpPredicate::ne:
+    binaryOperator = "!=";
+    break;
+  case emitc::CmpPredicate::lt:
+    binaryOperator = "<";
+    break;
+  case emitc::CmpPredicate::le:
+    binaryOperator = "<=";
+    break;
+  case emitc::CmpPredicate::gt:
+    binaryOperator = ">";
+    break;
+  case emitc::CmpPredicate::ge:
+    binaryOperator = ">=";
+    break;
+  case emitc::CmpPredicate::three_way:
+    binaryOperator = "<=>";
+    break;
+  default:
+    return cmpOp.emitError("unhandled comparison predicate");
+  }
+
+  return printBinaryOperation(emitter, operation, binaryOperator);
 }
 
 static LogicalResult printOperation(CppEmitter &emitter,
@@ -977,8 +1011,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
               [&](auto op) { return printOperation(*this, op); })
           // EmitC ops.
           .Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
-                emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp, emitc::MulOp,
-                emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
+                emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp,
+                emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(

diff  --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 4f08601d6ac403..279fe13229c594 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -79,3 +79,21 @@ func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<
   %4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> i32
   return
 }
+
+func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
+  %1 = "emitc.cmp" (%arg0, %arg0) {predicate = 0} : (i32, i32) -> i1
+  %2 = emitc.cmp eq, %arg0, %arg0 : (i32, i32) -> i1
+  %3 = "emitc.cmp" (%arg1, %arg1) {predicate = 1} : (f32, f32) -> i1
+  %4 = emitc.cmp ne, %arg1, %arg1 : (f32, f32) -> i1
+  %5 = "emitc.cmp" (%arg2, %arg2) {predicate = 2} : (i64, i64) -> i1
+  %6 = emitc.cmp lt, %arg2, %arg2 : (i64, i64) -> i1
+  %7 = "emitc.cmp" (%arg3, %arg3) {predicate = 3} : (f64, f64) -> i1
+  %8 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
+  %9 = "emitc.cmp" (%arg4, %arg4) {predicate = 4} : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
+  %10 = emitc.cmp gt, %arg4, %arg4 : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
+  %11 = "emitc.cmp" (%arg5, %arg5) {predicate = 5} : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
+  %12 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
+  %13 = "emitc.cmp" (%arg6, %arg6) {predicate = 6} : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
+  %14 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
+  return
+}

diff  --git a/mlir/test/Target/Cpp/comparison_operators.mlir b/mlir/test/Target/Cpp/comparison_operators.mlir
new file mode 100644
index 00000000000000..a3751214b92433
--- /dev/null
+++ b/mlir/test/Target/Cpp/comparison_operators.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
+  %1 = emitc.cmp eq, %arg0, %arg2 : (i32, i64) -> i1
+  %2 = emitc.cmp ne, %arg1, %arg3 : (f32, f64) -> i1
+  %3 = emitc.cmp lt, %arg2, %arg4 : (i64, !emitc.opaque<"unsigned">) -> !emitc.opaque<"int">
+  %4 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
+  %5 = emitc.cmp gt, %arg6, %arg4 : (!emitc.opaque<"custom">, !emitc.opaque<"unsigned">) -> !emitc.opaque<"custom">
+  %6 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
+  %7 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
+  
+  return
+}
+// CHECK-LABEL: void cmp
+// CHECK-NEXT:  bool [[V7:[^ ]*]] = [[V0:[^ ]*]] == [[V2:[^ ]*]];
+// CHECK-NEXT:  bool [[V8:[^ ]*]] = [[V1:[^ ]*]] != [[V3:[^ ]*]];
+// CHECK-NEXT:  int [[V9:[^ ]*]] = [[V2]] < [[V4:[^ ]*]];
+// CHECK-NEXT:  bool [[V10:[^ ]*]] = [[V3]] <= [[V3]];
+// CHECK-NEXT:  custom [[V11:[^ ]*]] = [[V6:[^ ]*]] > [[V4]];
+// CHECK-NEXT:  std::valarray<bool> [[V12:[^ ]*]] = [[V5:[^ ]*]] >= [[V5]];
+// CHECK-NEXT:  custom [[V13:[^ ]*]] = [[V6]] <=> [[V6]];


        


More information about the Mlir-commits mailing list