[Mlir-commits] [mlir] adea7e7 - [mlir][emitc] Add comparison operation
Marius Brehler
llvmlistbot at llvm.org
Tue Aug 29 09:51:08 PDT 2023
Author: Simon Camphausen
Date: 2023-08-29T16:50:32Z
New Revision: adea7e7032adc9d3ba2f9e18aea3d1cfbd6514b2
URL: https://github.com/llvm/llvm-project/commit/adea7e7032adc9d3ba2f9e18aea3d1cfbd6514b2
DIFF: https://github.com/llvm/llvm-project/commit/adea7e7032adc9d3ba2f9e18aea3d1cfbd6514b2.diff
LOG: [mlir][emitc] Add comparison operation
This adds a comparison operation to EmitC which supports ==, !=, <=, <, >=, >, <=>.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D158180
Added:
mlir/test/Target/Cpp/comparison_operators.mlir
Modified:
mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/test/Dialect/EmitC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
index 09a9f7a2ec1c59..ac8c651cdced89 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
@@ -2,6 +2,8 @@ add_mlir_dialect(EmitC emitc)
add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td)
+mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls)
+mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(EmitCAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(EmitCAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIREmitCAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 0acaa85139508f..b3c1170eefdab9 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -21,6 +21,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
+#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index a10b64da436140..9e0880089c9f87 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -27,8 +27,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class EmitC_Op<string mnemonic, list<Trait> traits = []>
: Op<EmitC_Dialect, mnemonic, traits>;
-// Base class for binary arithmetic operations.
-class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
+// Base class for binary operations.
+class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
let results = (outs AnyType);
@@ -39,7 +39,7 @@ class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
-def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
+def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
let summary = "Addition operation";
let description = [{
With the `add` operation the arithmetic operator + (addition) can
@@ -150,6 +150,37 @@ def EmitC_CastOp : EmitC_Op<"cast", [
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
}
+def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
+ let summary = "Comparison operation";
+ let description = [{
+ With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=>
+ can be applied.
+
+ Example:
+ ```mlir
+ // Custom form of the cmp operation.
+ %0 = emitc.cmp eq, %arg0, %arg1 : (i32, i32) -> i1
+ %1 = emitc.cmp lt, %arg2, %arg3 :
+ (
+ !emitc.opaque<"std::valarray<float>">,
+ !emitc.opaque<"std::valarray<float>">
+ ) -> !emitc.opaque<"std::valarray<bool>">
+ ```
+ ```c++
+ // Code emitted for the operations above.
+ bool v5 = v1 == v2;
+ std::valarray<bool> v6 = v3 < v4;
+ ```
+ }];
+
+ let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
+ AnyType:$lhs,
+ AnyType:$rhs);
+ let results = (outs AnyType);
+
+ let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
+}
+
def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
let summary = "Constant operation";
let description = [{
@@ -180,7 +211,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
let hasVerifier = 1;
}
-def EmitC_DivOp : EmitC_BinaryArithOp<"div", []> {
+def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let summary = "Division operation";
let description = [{
With the `div` operation the arithmetic operator / (division) can
@@ -248,7 +279,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
let assemblyFormat = "$value attr-dict `:` type($result)";
}
-def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
+def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
let summary = "Multiplication operation";
let description = [{
With the `mul` operation the arithmetic operator * (multiplication) can
@@ -272,7 +303,7 @@ def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}
-def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
+def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
let summary = "Remainder operation";
let description = [{
With the `rem` operation the arithmetic operator % (remainder) can
@@ -294,7 +325,7 @@ def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
let results = (outs IntegerIndexOrOpaqueType);
}
-def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
+def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
let summary = "Subtraction operation";
let description = [{
With the `sub` operation the arithmetic operator - (subtraction) can
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
index d69b1c20eaee5e..ae843e49c6c5ba 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
@@ -15,6 +15,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
//===----------------------------------------------------------------------===//
@@ -26,6 +27,20 @@ class EmitC_Attr<string name, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}
+def EmitC_CmpPredicateAttr : I64EnumAttr<
+ "CmpPredicate", "",
+ [
+ I64EnumAttrCase<"eq", 0>,
+ I64EnumAttrCase<"ne", 1>,
+ I64EnumAttrCase<"lt", 2>,
+ I64EnumAttrCase<"le", 3>,
+ I64EnumAttrCase<"gt", 4>,
+ I64EnumAttrCase<"ge", 5>,
+ I64EnumAttrCase<"three_way", 6>,
+ ]> {
+ let cppNamespace = "::mlir::emitc";
+}
+
def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
let summary = "An opaque attribute";
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ceb1e94ce8177d..0aac33e22cbc40 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -257,6 +257,12 @@ LogicalResult emitc::VariableOp::verify() {
#define GET_OP_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
+//===----------------------------------------------------------------------===//
+// EmitC Enums
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
+
//===----------------------------------------------------------------------===//
// EmitC Attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 396895d3b559da..c36e202f1b17e2 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -246,15 +246,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}
-static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
- Operation *operation,
- StringRef binaryArithOperator) {
+static LogicalResult printBinaryOperation(CppEmitter &emitter,
+ Operation *operation,
+ StringRef binaryOperator) {
raw_ostream &os = emitter.ostream();
if (failed(emitter.emitAssignPrefix(*operation)))
return failure();
os << emitter.getOrCreateName(operation->getOperand(0));
- os << " " << binaryArithOperator;
+ os << " " << binaryOperator;
os << " " << emitter.getOrCreateName(operation->getOperand(1));
return success();
@@ -263,31 +263,65 @@ static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
Operation *operation = addOp.getOperation();
- return printBinaryArithOperation(emitter, operation, "+");
+ return printBinaryOperation(emitter, operation, "+");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
Operation *operation = divOp.getOperation();
- return printBinaryArithOperation(emitter, operation, "/");
+ return printBinaryOperation(emitter, operation, "/");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
Operation *operation = mulOp.getOperation();
- return printBinaryArithOperation(emitter, operation, "*");
+ return printBinaryOperation(emitter, operation, "*");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
Operation *operation = remOp.getOperation();
- return printBinaryArithOperation(emitter, operation, "%");
+ return printBinaryOperation(emitter, operation, "%");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
Operation *operation = subOp.getOperation();
- return printBinaryArithOperation(emitter, operation, "-");
+ return printBinaryOperation(emitter, operation, "-");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
+ Operation *operation = cmpOp.getOperation();
+
+ StringRef binaryOperator;
+
+ switch (cmpOp.getPredicate()) {
+ case emitc::CmpPredicate::eq:
+ binaryOperator = "==";
+ break;
+ case emitc::CmpPredicate::ne:
+ binaryOperator = "!=";
+ break;
+ case emitc::CmpPredicate::lt:
+ binaryOperator = "<";
+ break;
+ case emitc::CmpPredicate::le:
+ binaryOperator = "<=";
+ break;
+ case emitc::CmpPredicate::gt:
+ binaryOperator = ">";
+ break;
+ case emitc::CmpPredicate::ge:
+ binaryOperator = ">=";
+ break;
+ case emitc::CmpPredicate::three_way:
+ binaryOperator = "<=>";
+ break;
+ default:
+ return cmpOp.emitError("unhandled comparison predicate");
+ }
+
+ return printBinaryOperation(emitter, operation, binaryOperator);
}
static LogicalResult printOperation(CppEmitter &emitter,
@@ -977,8 +1011,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
- emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp, emitc::MulOp,
- emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
+ emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp,
+ emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 4f08601d6ac403..279fe13229c594 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -79,3 +79,21 @@ func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<
%4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> i32
return
}
+
+func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
+ %1 = "emitc.cmp" (%arg0, %arg0) {predicate = 0} : (i32, i32) -> i1
+ %2 = emitc.cmp eq, %arg0, %arg0 : (i32, i32) -> i1
+ %3 = "emitc.cmp" (%arg1, %arg1) {predicate = 1} : (f32, f32) -> i1
+ %4 = emitc.cmp ne, %arg1, %arg1 : (f32, f32) -> i1
+ %5 = "emitc.cmp" (%arg2, %arg2) {predicate = 2} : (i64, i64) -> i1
+ %6 = emitc.cmp lt, %arg2, %arg2 : (i64, i64) -> i1
+ %7 = "emitc.cmp" (%arg3, %arg3) {predicate = 3} : (f64, f64) -> i1
+ %8 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
+ %9 = "emitc.cmp" (%arg4, %arg4) {predicate = 4} : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
+ %10 = emitc.cmp gt, %arg4, %arg4 : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
+ %11 = "emitc.cmp" (%arg5, %arg5) {predicate = 5} : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
+ %12 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
+ %13 = "emitc.cmp" (%arg6, %arg6) {predicate = 6} : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
+ %14 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
+ return
+}
diff --git a/mlir/test/Target/Cpp/comparison_operators.mlir b/mlir/test/Target/Cpp/comparison_operators.mlir
new file mode 100644
index 00000000000000..a3751214b92433
--- /dev/null
+++ b/mlir/test/Target/Cpp/comparison_operators.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
+ %1 = emitc.cmp eq, %arg0, %arg2 : (i32, i64) -> i1
+ %2 = emitc.cmp ne, %arg1, %arg3 : (f32, f64) -> i1
+ %3 = emitc.cmp lt, %arg2, %arg4 : (i64, !emitc.opaque<"unsigned">) -> !emitc.opaque<"int">
+ %4 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
+ %5 = emitc.cmp gt, %arg6, %arg4 : (!emitc.opaque<"custom">, !emitc.opaque<"unsigned">) -> !emitc.opaque<"custom">
+ %6 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
+ %7 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
+
+ return
+}
+// CHECK-LABEL: void cmp
+// CHECK-NEXT: bool [[V7:[^ ]*]] = [[V0:[^ ]*]] == [[V2:[^ ]*]];
+// CHECK-NEXT: bool [[V8:[^ ]*]] = [[V1:[^ ]*]] != [[V3:[^ ]*]];
+// CHECK-NEXT: int [[V9:[^ ]*]] = [[V2]] < [[V4:[^ ]*]];
+// CHECK-NEXT: bool [[V10:[^ ]*]] = [[V3]] <= [[V3]];
+// CHECK-NEXT: custom [[V11:[^ ]*]] = [[V6:[^ ]*]] > [[V4]];
+// CHECK-NEXT: std::valarray<bool> [[V12:[^ ]*]] = [[V5:[^ ]*]] >= [[V5]];
+// CHECK-NEXT: custom [[V13:[^ ]*]] = [[V6]] <=> [[V6]];
More information about the Mlir-commits mailing list