[Mlir-commits] [mlir] 0f1ac5e - [mlir][emitc] Add add and sub operations
Marius Brehler
llvmlistbot at llvm.org
Fri Jun 23 05:27:11 PDT 2023
Author: Marius Brehler
Date: 2023-06-23T12:15:06Z
New Revision: 0f1ac5e110cc2d4d5575238b52382253ad7368b0
URL: https://github.com/llvm/llvm-project/commit/0f1ac5e110cc2d4d5575238b52382253ad7368b0
DIFF: https://github.com/llvm/llvm-project/commit/0f1ac5e110cc2d4d5575238b52382253ad7368b0.diff
LOG: [mlir][emitc] Add add and sub operations
This adds operations for binary additive operators to EmitC. The input
arguments to these ops can be EmitC pointers and thus the operations can
be used for pointer arithmetic.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D149963
Added:
mlir/test/Target/Cpp/arithmetic_operators.mlir
Modified:
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/test/Dialect/EmitC/invalid_ops.mlir
mlir/test/Dialect/EmitC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index a53dea49b8adc..93cffcd891cb8 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -27,6 +27,37 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class EmitC_Op<string mnemonic, list<Trait> traits = []>
: Op<EmitC_Dialect, mnemonic, traits>;
+// Base class for binary arithmetic operations.
+class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
+ EmitC_Op<mnemonic, traits> {
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType);
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+
+ let hasVerifier = 1;
+}
+
+def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
+ let summary = "Addition operation";
+ let description = [{
+ With the `add` operation the arithmetic operator + (addition) can
+ be applied.
+
+ Example:
+
+ ```mlir
+ // Custom form of the addition operation.
+ %0 = emitc.add %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.add %arg2, %arg3 : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+ ```
+ ```c++
+ // Code emitted for the operations above.
+ int32_t v5 = v1 + v2;
+ float* v6 = v3 + v4;
+ ```
+ }];
+}
+
def EmitC_ApplyOp : EmitC_Op<"apply", []> {
let summary = "Apply operation";
let description = [{
@@ -175,6 +206,30 @@ def EmitC_IncludeOp
let hasCustomAssemblyFormat = 1;
}
+def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
+ let summary = "Subtraction operation";
+ let description = [{
+ With the `sub` operation the arithmetic operator - (subtraction) can
+ be applied.
+
+ Example:
+
+ ```mlir
+ // Custom form of the subtraction operation.
+ %0 = emitc.sub %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.sub %arg2, %arg3 : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+ %2 = emitc.sub %arg4, %arg5 : (!emitc.ptr<i32>, !emitc.ptr<i32>)
+ -> !emitc.opaque<"ptrdiff_t">
+ ```
+ ```c++
+ // Code emitted for the operations above.
+ int32_t v7 = v1 - v2;
+ float* v8 = v3 - v4;
+ ptrdiff_t v9 = v5 - v6;
+ ```
+ }];
+}
+
def EmitC_VariableOp : EmitC_Op<"variable", []> {
let summary = "Variable operation";
let description = [{
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index fc2e17a209705..a99a7d2b72f54 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -44,6 +44,27 @@ Operation *EmitCDialect::materializeConstant(OpBuilder &builder,
return builder.create<emitc::ConstantOp>(loc, type, value);
}
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AddOp::verify() {
+ Type lhsType = getLhs().getType();
+ Type rhsType = getRhs().getType();
+
+ if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>())
+ return emitOpError("requires that at most one operand is a pointer");
+
+ if ((lhsType.isa<emitc::PointerType>() &&
+ !rhsType.isa<IntegerType, emitc::OpaqueType>()) ||
+ (rhsType.isa<emitc::PointerType>() &&
+ !lhsType.isa<IntegerType, emitc::OpaqueType>()))
+ return emitOpError("requires that one operand is an integer or of opaque "
+ "type if the other is a pointer");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ApplyOp
//===----------------------------------------------------------------------===//
@@ -178,6 +199,31 @@ ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
+//===----------------------------------------------------------------------===//
+// SubOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SubOp::verify() {
+ Type lhsType = getLhs().getType();
+ Type rhsType = getRhs().getType();
+ Type resultType = getResult().getType();
+
+ if (rhsType.isa<emitc::PointerType>() && !lhsType.isa<emitc::PointerType>())
+ return emitOpError("rhs can only be a pointer if lhs is a pointer");
+
+ if (lhsType.isa<emitc::PointerType>() &&
+ !rhsType.isa<IntegerType, emitc::OpaqueType, emitc::PointerType>())
+ return emitOpError("requires that rhs is an integer, pointer or of opaque "
+ "type if lhs is a pointer");
+
+ if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>() &&
+ !resultType.isa<IntegerType, emitc::OpaqueType>())
+ return emitOpError("requires that the result is an integer or of opaque "
+ "type if lhs and rhs are pointers");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// VariableOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index b7eb1f07f2d9b..a1f6886a8ae4d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -246,6 +246,32 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}
+static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
+ Operation *operation,
+ StringRef binaryArithOperator) {
+ raw_ostream &os = emitter.ostream();
+
+ if (failed(emitter.emitAssignPrefix(*operation)))
+ return failure();
+ os << emitter.getOrCreateName(operation->getOperand(0));
+ os << " " << binaryArithOperator;
+ os << " " << emitter.getOrCreateName(operation->getOperand(1));
+
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
+ Operation *operation = addOp.getOperation();
+
+ return printBinaryArithOperation(emitter, operation, "+");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
+ Operation *operation = subOp.getOperation();
+
+ return printBinaryArithOperation(emitter, operation, "-");
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
cf::BranchOp branchOp) {
raw_ostream &os = emitter.ostream();
@@ -930,8 +956,9 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<cf::BranchOp, cf::CondBranchOp>(
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
- .Case<emitc::ApplyOp, emitc::CallOp, emitc::CastOp, emitc::ConstantOp,
- emitc::IncludeOp, emitc::VariableOp>(
+ .Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
+ emitc::ConstantOp, emitc::IncludeOp, emitc::SubOp,
+ emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 377c84ef90dbd..2c45d49a7e371 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -118,3 +118,52 @@ func.func @cast_tensor(%arg : tensor<f32>) {
%1 = emitc.cast %arg: tensor<f32> to tensor<f32>
return
}
+
+// -----
+
+func.func @add_two_pointers(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
+ // expected-error @+1 {{'emitc.add' op requires that at most one operand is a pointer}}
+ %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+ return
+}
+
+// -----
+
+func.func @add_pointer_float(%arg0: !emitc.ptr<f32>, %arg1: f32) {
+ // expected-error @+1 {{'emitc.add' op requires that one operand is an integer or of opaque type if the other is a pointer}}
+ %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, f32) -> !emitc.ptr<f32>
+ return
+}
+
+// -----
+
+func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
+ // expected-error @+1 {{'emitc.add' op requires that one operand is an integer or of opaque type if the other is a pointer}}
+ %1 = "emitc.add" (%arg0, %arg1) : (f32, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+ return
+}
+
+// -----
+
+func.func @sub_int_pointer(%arg0: i32, %arg1: !emitc.ptr<f32>) {
+ // expected-error @+1 {{'emitc.sub' op rhs can only be a pointer if lhs is a pointer}}
+ %1 = "emitc.sub" (%arg0, %arg1) : (i32, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+ return
+}
+
+
+// -----
+
+func.func @sub_pointer_float(%arg0: !emitc.ptr<f32>, %arg1: f32) {
+ // expected-error @+1 {{'emitc.sub' op requires that rhs is an integer, pointer or of opaque type if lhs is a pointer}}
+ %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, f32) -> !emitc.ptr<f32>
+ return
+}
+
+// -----
+
+func.func @sub_pointer_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
+ // expected-error @+1 {{'emitc.sub' op requires that the result is an integer or of opaque type if lhs and rhs are pointers}}
+ %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+ return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index b682aac381da6..a1d226b379430 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -30,3 +30,27 @@ func.func @a(%arg0: i32, %arg1: i32) {
%2 = emitc.apply "&"(%arg1) : (i32) -> !emitc.ptr<i32>
return
}
+
+func.func @add_int(%arg0: i32, %arg1: i32) {
+ %1 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32
+ return
+}
+
+func.func @add_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<"unsigned int">) {
+ %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+ %2 = "emitc.add" (%arg0, %arg2) : (!emitc.ptr<f32>, !emitc.opaque<"unsigned int">) -> !emitc.ptr<f32>
+ return
+}
+
+func.func @sub_int(%arg0: i32, %arg1: i32) {
+ %1 = "emitc.sub" (%arg0, %arg1) : (i32, i32) -> i32
+ return
+}
+
+func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<"unsigned int">, %arg3: !emitc.ptr<f32>) {
+ %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+ %2 = "emitc.sub" (%arg0, %arg2) : (!emitc.ptr<f32>, !emitc.opaque<"unsigned int">) -> !emitc.ptr<f32>
+ %3 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.opaque<"ptrdiff_t">
+ %4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> i32
+ return
+}
diff --git a/mlir/test/Target/Cpp/arithmetic_operators.mlir b/mlir/test/Target/Cpp/arithmetic_operators.mlir
new file mode 100644
index 0000000000000..0ce1af45ac9a7
--- /dev/null
+++ b/mlir/test/Target/Cpp/arithmetic_operators.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @add_int(%arg0: i32, %arg1: i32) {
+ %1 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32
+ return
+}
+// CHECK-LABEL: void add_int
+// CHECK-NEXT: int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] + [[V1:[^ ]*]]
+
+func.func @add_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32) {
+ %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+ return
+}
+// CHECK-LABEL: void add_pointer
+// CHECK-NEXT: float* [[V2:[^ ]*]] = [[V0:[^ ]*]] + [[V1:[^ ]*]]
+
+func.func @sub_int(%arg0: i32, %arg1: i32) {
+ %1 = "emitc.sub" (%arg0, %arg1) : (i32, i32) -> i32
+ return
+}
+// CHECK-LABEL: void sub_int
+// CHECK-NEXT: int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] - [[V1:[^ ]*]]
+
+func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
+ %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.opaque<"ptrdiff_t">
+ return
+}
+// CHECK-LABEL: void sub_pointer
+// CHECK-NEXT: ptrdiff_t [[V2:[^ ]*]] = [[V0:[^ ]*]] - [[V1:[^ ]*]]
More information about the Mlir-commits
mailing list