[Mlir-commits] [mlir] [mlir][emitc] Refactor emitc.apply op (PR #72569)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 16 12:54:23 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-emitc
Author: Gil Rapaport (aniragil)
<details>
<summary>Changes</summary>
The emitc.apply op models both C's address-taking and dereferencing operators
using an attribute to select the concrete opcode. This patch replaces
emitc.apply with a pair of emitc.address_of and emitc.dereference ops.
Unlike emitc.apply, which also supported taking the address of the C variables
expected to hold SSA values, the new emitc.address_of op restricts address-taking
to the C variables explicitly modeled by the dialect, by requiring its operand to
be defined by an emitc.variable op.
---
Full diff: https://github.com/llvm/llvm-project/pull/72569.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+32-19)
- (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+28-13)
- (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+30-16)
- (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+7-15)
- (modified) mlir/test/Dialect/EmitC/ops.mlir (+10-3)
- (modified) mlir/test/Target/Cpp/common-cpp.mlir (+4-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 2edeb6f8a9cf01e..53cd708e04aa48d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -63,31 +63,23 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
let hasVerifier = 1;
}
-def EmitC_ApplyOp : EmitC_Op<"apply", []> {
- let summary = "Apply operation";
+def EmitC_AddressOfOp : EmitC_Op<"address_of", []> {
+ let summary = "Address operation";
let description = [{
- With the `apply` operation the operators & (address of) and * (contents of)
- can be applied to a single operand.
+ This operation models the C & (address of) operator for a single operand, which
+ must be defined by an emitc.variable op. It returns an emitc pointer to the variable.
Example:
```mlir
// Custom form of applying the & operator.
- %0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>
-
- // Generic form of the same operation.
- %0 = "emitc.apply"(%arg0) {applicableOperator = "&"}
- : (i32) -> !emitc.ptr<i32>
-
+ %0 = emitc.address_of %arg0 : (i32) -> !emitc.ptr<i32>
```
}];
- let arguments = (ins
- Arg<StrAttr, "the operator to apply">:$applicableOperator,
- AnyType:$operand
- );
- let results = (outs AnyType:$result);
+ let arguments = (ins AnyType:$var);
+ let results = (outs EmitC_PointerType:$result);
let assemblyFormat = [{
- $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
+ $var attr-dict `:` functional-type($var, $result)
}];
let hasVerifier = 1;
}
@@ -222,6 +214,27 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
let hasVerifier = 1;
}
+def EmitC_DereferenceOp : EmitC_Op<"dereference", []> {
+ let summary = "Dereference operation";
+ let description = [{
+ This operation models the C * (dereference) operator for a single operand which
+ must be of !emitc.ptr<> type. It returns the value pointed to by the pointer.
+
+ Example:
+
+ ```mlir
+ // Custom form of applying the * operator.
+ %0 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> i32
+ ```
+ }];
+ let arguments = (ins EmitC_PointerType:$pointer);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = [{
+ $pointer attr-dict `:` functional-type($pointer, $result)
+ }];
+ let hasVerifier = 1;
+}
+
def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let summary = "Division operation";
let description = [{
@@ -448,12 +461,12 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
Since folding is not supported, it can be used with pointers.
As an example, it is valid to create pointers to `variable` operations
- by using `apply` operations and pass these to a `call` operation.
+ by using `address_of` operations and pass these to a `call` operation.
```mlir
%0 = "emitc.variable"() {value = 0 : i32} : () -> i32
%1 = "emitc.variable"() {value = 0 : i32} : () -> i32
- %2 = emitc.apply "&"(%0) : (i32) -> !emitc.ptr<i32>
- %3 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
+ %2 = emitc.address_of %0 : (i32) -> !emitc.ptr<i32>
+ %3 = emitc.address_of %1 : (i32) -> !emitc.ptr<i32>
emitc.call "write"(%2, %3) : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> ()
```
}];
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index d06381b7ddad3dc..cb7bf857e27ae60 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -72,23 +72,21 @@ LogicalResult AddOp::verify() {
}
//===----------------------------------------------------------------------===//
-// ApplyOp
+// AddressOfOp
//===----------------------------------------------------------------------===//
-LogicalResult ApplyOp::verify() {
- StringRef applicableOperatorStr = getApplicableOperator();
-
- // Applicable operator must not be empty.
- if (applicableOperatorStr.empty())
- return emitOpError("applicable operator must not be empty");
+LogicalResult AddressOfOp::verify() {
+ Value variable = getVar();
+ auto variableDef = dyn_cast_if_present<VariableOp>(variable.getDefiningOp());
+ if (!variableDef)
+ return emitOpError() << "requires operand to be a variable";
- // Only `*` and `&` are supported.
- if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
- return emitOpError("applicable operator is illegal");
+ Type variableType = variable.getType();
+ emitc::PointerType resultType = getResult().getType();
+ Type pointeeType = resultType.getPointee();
- Operation *op = getOperand().getDefiningOp();
- if (op && dyn_cast<ConstantOp>(op))
- return emitOpError("cannot apply to constant");
+ if (variableType != pointeeType)
+ return emitOpError("requires variable to be of type pointed to by result");
return success();
}
@@ -189,6 +187,23 @@ LogicalResult emitc::ConstantOp::verify() {
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
+//===----------------------------------------------------------------------===//
+// DereferenceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DereferenceOp::verify() {
+ auto pointer = getPointer();
+ emitc::PointerType pointerType = pointer.getType();
+ Type pointeeType = pointerType.getPointee();
+ Type resultType = getResult().getType();
+
+ if (pointeeType != resultType)
+ return emitOpError()
+ << "requires result to be of type pointed to by operand";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 291624c5480318d..1946d4f5a6ec11b 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -213,6 +213,19 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
return emitter.emitAttribute(operation->getLoc(), value);
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::AddressOfOp addressOfOp) {
+ raw_ostream &os = emitter.ostream();
+ Operation &op = *addressOfOp.getOperation();
+
+ if (failed(emitter.emitAssignPrefix(op)))
+ return failure();
+ os << "&";
+ os << emitter.getOrCreateName(addressOfOp.getOperand());
+
+ return success();
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConstantOp constantOp) {
Operation *operation = constantOp.getOperation();
@@ -461,30 +474,30 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
return success();
}
-static LogicalResult printOperation(CppEmitter &emitter,
- emitc::ApplyOp applyOp) {
+static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
raw_ostream &os = emitter.ostream();
- Operation &op = *applyOp.getOperation();
+ Operation &op = *castOp.getOperation();
if (failed(emitter.emitAssignPrefix(op)))
return failure();
- os << applyOp.getApplicableOperator();
- os << emitter.getOrCreateName(applyOp.getOperand());
+ os << "(";
+ if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
+ return failure();
+ os << ") ";
+ os << emitter.getOrCreateName(castOp.getOperand());
return success();
}
-static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::DereferenceOp dereferenceOp) {
raw_ostream &os = emitter.ostream();
- Operation &op = *castOp.getOperation();
+ Operation &op = *dereferenceOp.getOperation();
if (failed(emitter.emitAssignPrefix(op)))
return failure();
- os << "(";
- if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
- return failure();
- os << ") ";
- os << emitter.getOrCreateName(castOp.getOperand());
+ os << "*";
+ os << emitter.getOrCreateName(dereferenceOp.getOperand());
return success();
}
@@ -949,10 +962,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<cf::BranchOp, cf::CondBranchOp>(
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
- .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
- emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp,
- emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
- emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
+ .Case<emitc::AddOp, emitc::AddressOfOp, emitc::AssignOp,
+ emitc::CallOp, emitc::CastOp, emitc::CmpOp, emitc::ConstantOp,
+ emitc::DereferenceOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
+ emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
+ emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 53d88adf4305ff8..04a2223c4448e10 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -72,26 +72,18 @@ func.func @dense_template_argument(%arg : i32) {
// -----
-func.func @empty_operator(%arg : i32) {
- // expected-error @+1 {{'emitc.apply' op applicable operator must not be empty}}
- %2 = emitc.apply ""(%arg) : (i32) -> !emitc.ptr<i32>
- return
-}
-
-// -----
-
-func.func @illegal_operator(%arg : i32) {
- // expected-error @+1 {{'emitc.apply' op applicable operator is illegal}}
- %2 = emitc.apply "+"(%arg) : (i32) -> !emitc.ptr<i32>
+func.func @illegal_address_of_operand() {
+ %1 = "emitc.constant"(){value = 42: i32} : () -> i32
+ // expected-error @+1 {{'emitc.address_of' op requires operand to be a variable}}
+ %2 = emitc.address_of %1 : (i32) -> !emitc.ptr<i32>
return
}
// -----
-func.func @illegal_operand() {
- %1 = "emitc.constant"(){value = 42: i32} : () -> i32
- // expected-error @+1 {{'emitc.apply' op cannot apply to constant}}
- %2 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
+func.func @illegal_dereference_operand(%arg0 : !emitc.ptr<i32>) {
+ // expected-error @+1 {{'emitc.dereference' op requires result to be of type pointed to by operand}}
+ %2 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> (f32)
return
}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 6c8398680980466..50358229b787b96 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -25,9 +25,10 @@ func.func @c() {
return
}
-func.func @a(%arg0: i32, %arg1: i32) {
- %1 = "emitc.apply"(%arg0) {applicableOperator = "&"} : (i32) -> !emitc.ptr<i32>
- %2 = emitc.apply "&"(%arg1) : (i32) -> !emitc.ptr<i32>
+func.func @a() {
+ %arg0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ %1 = "emitc.address_of"(%arg0) : (i32) -> !emitc.ptr<i32>
+ %2 = emitc.address_of %arg0 : (i32) -> !emitc.ptr<i32>
return
}
@@ -47,6 +48,12 @@ func.func @div_int(%arg0: i32, %arg1: i32) {
return
}
+func.func @dereference(%arg0: !emitc.ptr<i32>) {
+ %1 = "emitc.dereference"(%arg0) : (!emitc.ptr<i32>) -> (i32)
+ %2 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> (i32)
+ return
+}
+
func.func @div_float(%arg0: f32, %arg1: f32) {
%1 = "emitc.div" (%arg0, %arg1) : (f32, f32) -> f32
return
diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir
index 252f5e214840da5..b1280a7345328d8 100644
--- a/mlir/test/Target/Cpp/common-cpp.mlir
+++ b/mlir/test/Target/Cpp/common-cpp.mlir
@@ -82,10 +82,11 @@ func.func @opaque_types(%arg0: !emitc.opaque<"bool">, %arg1: !emitc.opaque<"char
return %2 : !emitc.opaque<"status_t">
}
-func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
+func.func @apply() -> !emitc.ptr<i32> {
+ %arg0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
// CHECK: int32_t* [[V2]] = &[[V1]];
- %0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>
+ %0 = emitc.address_of %arg0 : (i32) -> !emitc.ptr<i32>
// CHECK: int32_t [[V3]] = *[[V2]];
- %1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
+ %1 = emitc.dereference %0 : (!emitc.ptr<i32>) -> (i32)
return %0 : !emitc.ptr<i32>
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/72569
More information about the Mlir-commits
mailing list