[Mlir-commits] [mlir] Unify handling of operations which are emitted in a deferred way (PR #97804)

Simon Camphausen llvmlistbot at llvm.org
Mon Jul 8 02:36:04 PDT 2024


https://github.com/simon-camp updated https://github.com/llvm/llvm-project/pull/97804

>From f06761b99fd3dca515969694f83227f4f64c15c6 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Thu, 6 Jun 2024 07:46:53 +0000
Subject: [PATCH 1/3] Unify handling of operations which are emitted in a
 deferred way

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 104 ++++++++++++++-----------
 1 file changed, 58 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 626638282efe1..038220d479068 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -174,6 +174,9 @@ struct CppEmitter {
   /// Emit an expression as a C expression.
   LogicalResult emitExpression(ExpressionOp expressionOp);
 
+  /// Insert the expression representing the operation into the value cache.
+  LogicalResult cacheDeferredOpResult(Operation *op);
+
   /// Return the existing or a new name for a Value.
   StringRef getOrCreateName(Value val);
 
@@ -273,6 +276,12 @@ struct CppEmitter {
 };
 } // namespace
 
+/// Determine whether operation \p op should be emitted in a deferred way.
+static bool hasDeferredEmission(Operation *op) {
+  return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp,
+                         emitc::SubscriptOp>(op);
+}
+
 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
 /// as part of its user. This function recommends inlining of any expressions
 /// that can be inlined unless it is used by another expression, under the
@@ -295,10 +304,10 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
 
   Operation *user = *result.getUsers().begin();
 
-  // Do not inline expressions used by subscript operations, since the
-  // way the subscript operation translation is implemented requires that
-  // variables be materialized.
-  if (isa<emitc::SubscriptOp>(user))
+  // Do not inline expressions used by operations with deferred emission, since
+  // the way their translation is implemented requires that variables be
+  // materialized.
+  if (hasDeferredEmission(user))
     return false;
 
   // Do not inline expressions used by ops with the CExpression trait. If this
@@ -370,20 +379,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return emitter.emitOperand(assignOp.getValue());
 }
 
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    emitc::GetGlobalOp op) {
-  // Add name to cache so that `hasValueInScope` works.
-  emitter.getOrCreateName(op.getResult());
-  return success();
-}
-
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    emitc::SubscriptOp subscriptOp) {
-  // Add name to cache so that `hasValueInScope` works.
-  emitter.getOrCreateName(subscriptOp.getResult());
-  return success();
-}
-
 static LogicalResult printBinaryOperation(CppEmitter &emitter,
                                           Operation *operation,
                                           StringRef binaryOperator) {
@@ -621,9 +616,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
       if (t.getType().isIndex()) {
         int64_t idx = t.getInt();
         Value operand = op.getOperand(idx);
-        auto literalDef =
-            dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
-        if (!literalDef && !emitter.hasValueInScope(operand))
+        if (!emitter.hasValueInScope(operand))
           return op.emitOpError("operand ")
                  << idx << "'s value not defined in scope";
         os << emitter.getOrCreateName(operand);
@@ -948,8 +941,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
     // regions.
     WalkResult result =
         functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
-          if (isa<emitc::LiteralOp>(op) ||
-              isa<emitc::ExpressionOp>(op->getParentOp()) ||
+          if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
               (isa<emitc::ExpressionOp>(op) &&
                shouldBeInlined(cast<emitc::ExpressionOp>(op))))
             return WalkResult::skip();
@@ -1001,7 +993,8 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
       // trailing semicolon is handled within the printOperation function.
       bool trailingSemicolon =
           !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
-               emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
+               emitc::IfOp, emitc::VerbatimOp>(op) ||
+          hasDeferredEmission(&op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1134,20 +1127,41 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
   return out;
 }
 
+LogicalResult CppEmitter::cacheDeferredOpResult(Operation *op) {
+  if (op->getNumResults() != 1)
+    return op->emitError("Adding deferred ops into value cache only works for "
+                         "single result operations, got ")
+           << op->getNumResults() << " results";
+
+  Value result = op->getResult(0);
+  if (valueMapper.count(result))
+    return success();
+
+  if (auto getGlobal = dyn_cast<emitc::GetGlobalOp>(op)) {
+    valueMapper.insert(result, getGlobal.getName().str());
+    return success();
+  }
+
+  if (auto literal = dyn_cast<emitc::LiteralOp>(op)) {
+    valueMapper.insert(result, literal.getValue().str());
+    return success();
+  }
+
+  if (auto subscript = dyn_cast<emitc::SubscriptOp>(op)) {
+    valueMapper.insert(result, getSubscriptName(subscript));
+    return success();
+  }
+
+  return op->emitError("cacheDeferredOpResult not implemented");
+}
+
 /// Return the existing or a new name for a Value.
 StringRef CppEmitter::getOrCreateName(Value val) {
-  if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
-    return literal.getValue();
   if (!valueMapper.count(val)) {
-    if (auto subscript =
-            dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
-      valueMapper.insert(val, getSubscriptName(subscript));
-    } else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
-                   val.getDefiningOp())) {
-      valueMapper.insert(val, getGlobal.getName().str());
-    } else {
-      valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
-    }
+    assert(!hasDeferredEmission(val.getDefiningOp()) &&
+           "cacheDeferredOpResult should have been called on this value, "
+           "update the emitOperation function.");
+    valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
   }
   return *valueMapper.begin(val);
 }
@@ -1341,9 +1355,6 @@ LogicalResult CppEmitter::emitOperand(Value value) {
   if (expressionOp && shouldBeInlined(expressionOp))
     return emitExpression(expressionOp);
 
-  auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
-  if (!literalOp && !hasValueInScope(value))
-    return failure();
   os << getOrCreateName(value);
   return success();
 }
@@ -1399,7 +1410,7 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
 
 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
                                                   bool trailingSemicolon) {
-  if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
+  if (hasDeferredEmission(result.getDefiningOp()))
     return success();
   if (hasValueInScope(result)) {
     return result.getDefiningOp()->emitError(
@@ -1498,16 +1509,17 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
                 emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
                 emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
-                emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
-                emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
-                emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
-                emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
-                emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
+                emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
+                emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
+                emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
+                emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
               [&](auto op) { return printOperation(*this, op); })
-          .Case<emitc::LiteralOp>([&](auto op) { return success(); })
+          .Case<emitc::GetGlobalOp, emitc::LiteralOp, emitc::SubscriptOp>(
+              [&](Operation *op) { return cacheDeferredOpResult(op); })
           .Default([&](Operation *) {
             return op.emitOpError("unable to find printer for op");
           });
@@ -1515,7 +1527,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
   if (failed(status))
     return failure();
 
-  if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
+  if (hasDeferredEmission(&op))
     return success();
 
   if (getEmittedExpression() ||

>From 3176b812958682bd1d53770d201f858ba11235ad Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Fri, 5 Jul 2024 10:03:17 +0000
Subject: [PATCH 2/3] Make globals assignable

---
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp      |  5 +-
 mlir/test/Dialect/EmitC/invalid_ops.mlir |  2 +-
 mlir/test/Dialect/EmitC/ops.mlir         |  6 ++
 mlir/test/Target/Cpp/global.mlir         | 84 +++++++++++++++++++-----
 4 files changed, 79 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b2556bb6065d8..9f99eb1233cb1 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -213,9 +213,10 @@ LogicalResult emitc::AssignOp::verify() {
   Value variable = getVar();
   Operation *variableDef = variable.getDefiningOp();
   if (!variableDef ||
-      !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
+      !llvm::isa<emitc::GetGlobalOp, emitc::SubscriptOp, emitc::VariableOp>(
+          variableDef))
     return emitOpError() << "requires first operand (" << variable
-                         << ") to be a Variable or subscript";
+                         << ") to be a get_global, subscript or variable";
 
   Value value = getValue();
   if (variable.getType() != value.getType())
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 8cd8bdca4df33..e9b11421882f9 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -235,7 +235,7 @@ func.func @test_misplaced_yield() {
 // -----
 
 func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
-  // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable or subscript}}
+  // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a get_global, subscript or variable}}
   emitc.assign %arg1 : f32 to %arg2 : f32
   return
 }
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 51c484a633eec..1d3ca5c9bc939 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -248,3 +248,9 @@ func.func @use_global(%i: index) -> f32 {
   %1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
   return %1 : f32
 }
+
+func.func @assign_global(%arg0 : i32) {
+  %0 = emitc.get_global @myglobal_int : i32
+  emitc.assign %arg0 : i32 to %0 : i32
+  return
+}
diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir
index f0d92e862ae32..78b459836aa51 100644
--- a/mlir/test/Target/Cpp/global.mlir
+++ b/mlir/test/Target/Cpp/global.mlir
@@ -1,38 +1,92 @@
-// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
-// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
 
 emitc.global extern @decl : i8
-// CHECK: extern int8_t decl;
+// CPP-DEFAULT: extern int8_t decl;
+// CPP-DECLTOP: extern int8_t decl;
 
 emitc.global @uninit : i32
-// CHECK: int32_t uninit;
+// CPP-DEFAULT: int32_t uninit;
+// CPP-DECLTOP: int32_t uninit;
 
 emitc.global @myglobal_int : i32 = 4
-// CHECK: int32_t myglobal_int = 4;
+// CPP-DEFAULT: int32_t myglobal_int = 4;
+// CPP-DECLTOP: int32_t myglobal_int = 4;
 
 emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
-// CHECK: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
+// CPP-DEFAULT: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
+// CPP-DECLTOP: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
 
 emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
-// CHECK: const int16_t myconstant[2] = {2, 2};
+// CPP-DEFAULT: const int16_t myconstant[2] = {2, 2};
+// CPP-DECLTOP: const int16_t myconstant[2] = {2, 2};
 
 emitc.global extern const @extern_constant : !emitc.array<2xi16>
-// CHECK: extern const int16_t extern_constant[2];
+// CPP-DEFAULT: extern const int16_t extern_constant[2];
+// CPP-DECLTOP: extern const int16_t extern_constant[2];
 
 emitc.global static @static_var : f32
-// CHECK: static float static_var;
+// CPP-DEFAULT: static float static_var;
+// CPP-DECLTOP: static float static_var;
 
 emitc.global static @static_const : f32 = 3.0
-// CHECK: static float static_const = 3.000000000e+00f;
+// CPP-DEFAULT: static float static_const = 3.000000000e+00f;
+// CPP-DECLTOP: static float static_const = 3.000000000e+00f;
 
 emitc.global @opaque_init : !emitc.opaque<"char"> = #emitc.opaque<"CHAR_MIN">
-// CHECK: char opaque_init = CHAR_MIN;
+// CPP-DEFAULT: char opaque_init = CHAR_MIN;
+// CPP-DECLTOP: char opaque_init = CHAR_MIN;
 
-func.func @use_global(%i: index) -> f32 {
+func.func @use_global_scalar_read() -> i32 {
+  %0 = emitc.get_global @myglobal_int : i32
+  return %0 : i32
+}
+// CPP-DEFAULT-LABEL: int32_t use_global_scalar_read()
+// CPP-DEFAULT-NEXT: return myglobal_int;
+
+// CPP-DECLTOP-LABEL: int32_t use_global_scalar_read()
+// CPP-DECLTOP-NEXT: return myglobal_int;
+
+func.func @use_global_scalar_write(%arg0 : i32) {
+  %0 = emitc.get_global @myglobal_int : i32
+  emitc.assign %arg0 : i32 to %0 : i32 
+  return
+}
+// CPP-DEFAULT-LABEL: void use_global_scalar_write
+// CPP-DEFAULT-SAME: (int32_t [[V1:.*]])
+// CPP-DEFAULT-NEXT: myglobal_int = [[V1]];
+// CPP-DEFAULT-NEXT: return;
+
+// CPP-DECLTOP-LABEL: void use_global_scalar_write
+// CPP-DECLTOP-SAME: (int32_t [[V1:.*]])
+// CPP-DECLTOP-NEXT: myglobal_int = [[V1]];
+// CPP-DECLTOP-NEXT: return;
+
+func.func @use_global_array_read(%i: index) -> f32 {
   %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
   %1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
   return %1 : f32
-  // CHECK-LABEL: use_global
-  // CHECK-SAME: (size_t [[V1:.*]])
-  // CHECK:   return myglobal[[[V1]]];
 }
+// CPP-DEFAULT-LABEL: float use_global_array_read
+// CPP-DEFAULT-SAME: (size_t [[V1:.*]])
+// CPP-DEFAULT-NEXT: return myglobal[[[V1]]];
+
+// CPP-DECLTOP-LABEL: float use_global_array_read
+// CPP-DECLTOP-SAME: (size_t [[V1:.*]])
+// CPP-DECLTOP-NEXT: return myglobal[[[V1]]];
+
+func.func @use_global_array_write(%i: index, %val : f32) {
+  %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
+  %1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
+  emitc.assign %val : f32 to %1 : f32 
+  return
+}
+// CPP-DEFAULT-LABEL: void use_global_array_write
+// CPP-DEFAULT-SAME: (size_t [[V1:.*]], float [[V2:.*]])
+// CPP-DEFAULT-NEXT: myglobal[[[V1]]] = [[V2]];
+// CPP-DEFAULT-NEXT: return;
+
+// CPP-DECLTOP-LABEL: void use_global_array_write
+// CPP-DECLTOP-SAME: (size_t [[V1:.*]], float [[V2:.*]])
+// CPP-DECLTOP-NEXT: myglobal[[[V1]]] = [[V2]];
+// CPP-DECLTOP-NEXT: return;
\ No newline at end of file

>From e5c6319ace8eab8d8bf28cddd95b9c30ff34a582 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Mon, 8 Jul 2024 09:35:32 +0000
Subject: [PATCH 3/3] Refactor code

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 48 +++++++++-----------------
 1 file changed, 17 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 038220d479068..51d17481268e4 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -175,7 +175,7 @@ struct CppEmitter {
   LogicalResult emitExpression(ExpressionOp expressionOp);
 
   /// Insert the expression representing the operation into the value cache.
-  LogicalResult cacheDeferredOpResult(Operation *op);
+  void cacheDeferredOpResult(Value value, StringRef str);
 
   /// Return the existing or a new name for a Value.
   StringRef getOrCreateName(Value val);
@@ -993,8 +993,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
       // trailing semicolon is handled within the printOperation function.
       bool trailingSemicolon =
           !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
-               emitc::IfOp, emitc::VerbatimOp>(op) ||
-          hasDeferredEmission(&op);
+               emitc::IfOp, emitc::VerbatimOp>(op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1127,32 +1126,9 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
   return out;
 }
 
-LogicalResult CppEmitter::cacheDeferredOpResult(Operation *op) {
-  if (op->getNumResults() != 1)
-    return op->emitError("Adding deferred ops into value cache only works for "
-                         "single result operations, got ")
-           << op->getNumResults() << " results";
-
-  Value result = op->getResult(0);
-  if (valueMapper.count(result))
-    return success();
-
-  if (auto getGlobal = dyn_cast<emitc::GetGlobalOp>(op)) {
-    valueMapper.insert(result, getGlobal.getName().str());
-    return success();
-  }
-
-  if (auto literal = dyn_cast<emitc::LiteralOp>(op)) {
-    valueMapper.insert(result, literal.getValue().str());
-    return success();
-  }
-
-  if (auto subscript = dyn_cast<emitc::SubscriptOp>(op)) {
-    valueMapper.insert(result, getSubscriptName(subscript));
-    return success();
-  }
-
-  return op->emitError("cacheDeferredOpResult not implemented");
+void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
+  if (!valueMapper.count(value))
+    valueMapper.insert(value, str.str());
 }
 
 /// Return the existing or a new name for a Value.
@@ -1518,8 +1494,18 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
               [&](auto op) { return printOperation(*this, op); })
-          .Case<emitc::GetGlobalOp, emitc::LiteralOp, emitc::SubscriptOp>(
-              [&](Operation *op) { return cacheDeferredOpResult(op); })
+          .Case<emitc::GetGlobalOp>([&](auto op) {
+            cacheDeferredOpResult(op.getResult(), op.getName());
+            return success();
+          })
+          .Case<emitc::LiteralOp>([&](auto op) {
+            cacheDeferredOpResult(op.getResult(), op.getValue());
+            return success();
+          })
+          .Case<emitc::SubscriptOp>([&](auto op) {
+            cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
+            return success();
+          })
           .Default([&](Operation *) {
             return op.emitOpError("unable to find printer for op");
           });



More information about the Mlir-commits mailing list