[Mlir-commits] [mlir] [mlir][emitc] Refactor emitc.apply op (PR #72569)

Gil Rapaport llvmlistbot at llvm.org
Thu Nov 16 12:53:56 PST 2023


https://github.com/aniragil created https://github.com/llvm/llvm-project/pull/72569

The emitc.apply op models both C's address-taking and dereferencing operators
using an attribute to select the concrete opcode. This patch replaces
emitc.apply with a pair of emitc.address_of and emitc.dereference ops.

Unlike emitc.apply, which also allowed taking the address of the C variables
that the emitter creates to hold SSA values, the new emitc.address_of op
restricts address-taking to C variables explicitly modeled by the dialect:
its operand must be defined by an emitc.variable op.


>From a955c18864768dcd0a60151493939d24e970b1b6 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Thu, 16 Nov 2023 18:08:57 +0200
Subject: [PATCH] [mlir][emitc] Refactor emitc.apply op

The emitc.apply op models both C's address-taking and dereferencing operators
using an attribute to select the concrete opcode. This patch replaces
emitc.apply with a pair of emitc.address_of and emitc.dereference ops.

Unlike emitc.apply, which also allowed taking the address of the C variables
that the emitter creates to hold SSA values, the new emitc.address_of op
restricts address-taking to C variables explicitly modeled by the dialect:
its operand must be defined by an emitc.variable op.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 51 +++++++++++++--------
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 41 +++++++++++------
 mlir/lib/Target/Cpp/TranslateToCpp.cpp      | 46 ++++++++++++-------
 mlir/test/Dialect/EmitC/invalid_ops.mlir    | 22 +++------
 mlir/test/Dialect/EmitC/ops.mlir            | 13 ++++--
 mlir/test/Target/Cpp/common-cpp.mlir        |  7 +--
 6 files changed, 111 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 2edeb6f8a9cf01e..53cd708e04aa48d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -63,31 +63,23 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
   let hasVerifier = 1;
 }
 
-def EmitC_ApplyOp : EmitC_Op<"apply", []> {
-  let summary = "Apply operation";
+def EmitC_AddressOfOp : EmitC_Op<"address_of", []> {
+  let summary = "Address-of operation";
   let description = [{
-    With the `apply` operation the operators & (address of) and * (contents of)
-    can be applied to a single operand.
+    This operation models the C & (address-of) operator for a single operand,
+    which must be defined by an emitc.variable op. It returns a pointer to it.
 
     Example:
 
     ```mlir
     // Custom form of applying the & operator.
-    %0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>
-
-    // Generic form of the same operation.
-    %0 = "emitc.apply"(%arg0) {applicableOperator = "&"}
-        : (i32) -> !emitc.ptr<i32>
-
+    %0 = emitc.address_of %arg0 : (i32) -> !emitc.ptr<i32>
     ```
   }];
-  let arguments = (ins
-    Arg<StrAttr, "the operator to apply">:$applicableOperator,
-    AnyType:$operand
-  );
-  let results = (outs AnyType:$result);
+  let arguments = (ins AnyType:$var);
+  let results = (outs EmitC_PointerType:$result);
   let assemblyFormat = [{
-    $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
+    $var attr-dict `:` functional-type($var, $result)
   }];
   let hasVerifier = 1;
 }
@@ -222,6 +214,27 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
   let hasVerifier = 1;
 }
 
+def EmitC_DereferenceOp : EmitC_Op<"dereference", []> {
+  let summary = "Dereference operation";
+  let description = [{
+    This operation models the C * (dereference) operator for a single operand,
+    which must be of !emitc.ptr<> type. It returns the pointed-to value.
+
+    Example:
+
+    ```mlir
+    // Custom form of applying the * operator.
+    %0 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> i32
+    ```
+  }];
+  let arguments = (ins EmitC_PointerType:$pointer);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = [{
+    $pointer attr-dict `:` functional-type($pointer, $result)
+  }];
+  let hasVerifier = 1;
+}
+
 def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
   let summary = "Division operation";
   let description = [{
@@ -448,12 +461,12 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
 
     Since folding is not supported, it can be used with pointers.
     As an example, it is valid to create pointers to `variable` operations
-    by using `apply` operations and pass these to a `call` operation.
+    by using `address_of` operations and pass these to a `call` operation.
     ```mlir
     %0 = "emitc.variable"() {value = 0 : i32} : () -> i32
     %1 = "emitc.variable"() {value = 0 : i32} : () -> i32
-    %2 = emitc.apply "&"(%0) : (i32) -> !emitc.ptr<i32>
-    %3 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
+    %2 = emitc.address_of %0 : (i32) -> !emitc.ptr<i32>
+    %3 = emitc.address_of %1 : (i32) -> !emitc.ptr<i32>
     emitc.call "write"(%2, %3) : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> ()
     ```
   }];
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index d06381b7ddad3dc..cb7bf857e27ae60 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -72,23 +72,21 @@ LogicalResult AddOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// ApplyOp
+// AddressOfOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult ApplyOp::verify() {
-  StringRef applicableOperatorStr = getApplicableOperator();
-
-  // Applicable operator must not be empty.
-  if (applicableOperatorStr.empty())
-    return emitOpError("applicable operator must not be empty");
+LogicalResult AddressOfOp::verify() {
+  Value variable = getVar();
+  // Only C variables modeled by emitc.variable may have their address taken.
+  if (!isa_and_present<VariableOp>(variable.getDefiningOp()))
+    return emitOpError("requires operand to be a variable");
 
-  // Only `*` and `&` are supported.
-  if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
-    return emitOpError("applicable operator is illegal");
+  Type variableType = variable.getType();
+  emitc::PointerType resultType = getResult().getType();
+  Type pointeeType = resultType.getPointee();
 
-  Operation *op = getOperand().getDefiningOp();
-  if (op && dyn_cast<ConstantOp>(op))
-    return emitOpError("cannot apply to constant");
+  if (variableType != pointeeType)
+    return emitOpError("requires variable to be of type pointed to by result");
 
   return success();
 }
@@ -189,6 +187,23 @@ LogicalResult emitc::ConstantOp::verify() {
 
 OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
+//===----------------------------------------------------------------------===//
+// DereferenceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DereferenceOp::verify() {
+  // ODS constrains the operand to !emitc.ptr, so no cast is needed here.
+  emitc::PointerType pointerType = getPointer().getType();
+  Type pointeeType = pointerType.getPointee();
+  Type resultType = getResult().getType();
+
+  // Dereferencing must yield exactly the type the pointer points to.
+  if (pointeeType != resultType)
+    return emitOpError("requires result to be of type pointed to by operand");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 291624c5480318d..1946d4f5a6ec11b 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -213,6 +213,19 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
   return emitter.emitAttribute(operation->getLoc(), value);
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::AddressOfOp addressOfOp) {
+  raw_ostream &os = emitter.ostream();
+  Operation &op = *addressOfOp.getOperation();
+
+  if (failed(emitter.emitAssignPrefix(op)))
+    return failure();
+  os << "&";
+  os << emitter.getOrCreateName(addressOfOp.getVar());
+
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::ConstantOp constantOp) {
   Operation *operation = constantOp.getOperation();
@@ -461,30 +474,30 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    emitc::ApplyOp applyOp) {
+static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
   raw_ostream &os = emitter.ostream();
-  Operation &op = *applyOp.getOperation();
+  Operation &op = *castOp.getOperation();
 
   if (failed(emitter.emitAssignPrefix(op)))
     return failure();
-  os << applyOp.getApplicableOperator();
-  os << emitter.getOrCreateName(applyOp.getOperand());
+  os << "(";
+  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
+    return failure();
+  os << ") ";
+  os << emitter.getOrCreateName(castOp.getOperand());
 
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::DereferenceOp dereferenceOp) {
   raw_ostream &os = emitter.ostream();
-  Operation &op = *castOp.getOperation();
+  Operation &op = *dereferenceOp.getOperation();
 
   if (failed(emitter.emitAssignPrefix(op)))
     return failure();
-  os << "(";
-  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
-    return failure();
-  os << ") ";
-  os << emitter.getOrCreateName(castOp.getOperand());
+  os << "*";
+  os << emitter.getOrCreateName(dereferenceOp.getPointer());
 
   return success();
 }
@@ -949,10 +962,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           .Case<cf::BranchOp, cf::CondBranchOp>(
               [&](auto op) { return printOperation(*this, op); })
           // EmitC ops.
-          .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
-                emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp,
-                emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
-                emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
+          .Case<emitc::AddOp, emitc::AddressOfOp, emitc::AssignOp,
+                emitc::CallOp, emitc::CastOp, emitc::CmpOp, emitc::ConstantOp,
+                emitc::DereferenceOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
+                emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
+                emitc::VariableOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 53d88adf4305ff8..04a2223c4448e10 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -72,26 +72,27 @@ func.func @dense_template_argument(%arg : i32) {
 
 // -----
 
-func.func @empty_operator(%arg : i32) {
-    // expected-error @+1 {{'emitc.apply' op applicable operator must not be empty}}
-    %2 = emitc.apply ""(%arg) : (i32) -> !emitc.ptr<i32>
-    return
-}
-
-// -----
-
-func.func @illegal_operator(%arg : i32) {
-    // expected-error @+1 {{'emitc.apply' op applicable operator is illegal}}
-    %2 = emitc.apply "+"(%arg) : (i32) -> !emitc.ptr<i32>
+func.func @illegal_address_of_operand() {
+    %1 = "emitc.constant"(){value = 42: i32} : () -> i32
+    // expected-error @+1 {{'emitc.address_of' op requires operand to be a variable}}
+    %2 = emitc.address_of %1 : (i32) -> !emitc.ptr<i32>
     return
 }
+
+// -----
+
+func.func @illegal_address_of_result_type() {
+    %1 = "emitc.variable"() <{value = 42 : i32}> : () -> i32
+    // expected-error @+1 {{'emitc.address_of' op requires variable to be of type pointed to by result}}
+    %2 = emitc.address_of %1 : (i32) -> !emitc.ptr<f32>
+    return
+}
 
 // -----
 
-func.func @illegal_operand() {
-    %1 = "emitc.constant"(){value = 42: i32} : () -> i32
-    // expected-error @+1 {{'emitc.apply' op cannot apply to constant}}
-    %2 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
+func.func @illegal_dereference_operand(%arg0 : !emitc.ptr<i32>) {
+    // expected-error @+1 {{'emitc.dereference' op requires result to be of type pointed to by operand}}
+    %2 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> (f32)
     return
 }
 
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 6c8398680980466..50358229b787b96 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -25,9 +25,10 @@ func.func @c() {
   return
 }
 
-func.func @a(%arg0: i32, %arg1: i32) {
-  %1 = "emitc.apply"(%arg0) {applicableOperator = "&"} : (i32) -> !emitc.ptr<i32>
-  %2 = emitc.apply "&"(%arg1) : (i32) -> !emitc.ptr<i32>
+func.func @a() {
+  %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %1 = "emitc.address_of"(%0) : (i32) -> !emitc.ptr<i32>
+  %2 = emitc.address_of %0 : (i32) -> !emitc.ptr<i32>
   return
 }
 
@@ -47,6 +48,12 @@ func.func @div_int(%arg0: i32, %arg1: i32) {
   return
 }
 
+func.func @dereference(%arg0: !emitc.ptr<i32>) {
+  %1 = "emitc.dereference"(%arg0) : (!emitc.ptr<i32>) -> (i32)
+  %2 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> (i32)
+  return
+}
+
 func.func @div_float(%arg0: f32, %arg1: f32) {
   %1 = "emitc.div" (%arg0, %arg1) : (f32, f32) -> f32
   return
diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir
index 252f5e214840da5..b1280a7345328d8 100644
--- a/mlir/test/Target/Cpp/common-cpp.mlir
+++ b/mlir/test/Target/Cpp/common-cpp.mlir
@@ -82,10 +82,11 @@ func.func @opaque_types(%arg0: !emitc.opaque<"bool">, %arg1: !emitc.opaque<"char
   return %2 : !emitc.opaque<"status_t">
 }
 
-func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
+func.func @apply() -> !emitc.ptr<i32> {
+  %var = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
   // CHECK: int32_t* [[V2]] = &[[V1]];
-  %0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>
+  %0 = emitc.address_of %var : (i32) -> !emitc.ptr<i32>
   // CHECK: int32_t [[V3]] = *[[V2]];
-  %1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
+  %1 = emitc.dereference %0 : (!emitc.ptr<i32>) -> (i32)
   return %0 : !emitc.ptr<i32>
 }



More information about the Mlir-commits mailing list