[Mlir-commits] [mlir] [mlir][emitc][cf] add 'cf.switch' support in CppEmitter (PR #101478)

Andrey Timonin llvmlistbot at llvm.org
Thu Aug 1 07:07:32 PDT 2024


https://github.com/EtoAndruwa updated https://github.com/llvm/llvm-project/pull/101478

>From cbc69f9ba73c2f7f5df042f78fb38a29e3748252 Mon Sep 17 00:00:00 2001
From: EtoAndruwa <timonina1909 at gmail.com>
Date: Thu, 1 Aug 2024 15:05:11 +0300
Subject: [PATCH] [mlir][emitc][cf] add 'cf.switch' support in CppEmitter

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 65 ++++++++++++++++++++++++--
 mlir/test/Target/Cpp/switch.mlir       | 41 ++++++++++++++++
 2 files changed, 103 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/switch.mlir

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 1dadb9dd691e7..9e988b9731e04 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -579,6 +579,77 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
+/// Emits the body of one arm of a `cf.switch`: assigns `operands` to the
+/// block arguments of `dest` and then jumps to the label of `dest`.
+/// `armKind` ("case" or "default") only appears in the diagnostic that is
+/// reported when `dest` has no label; nothing is printed for the arm then.
+static LogicalResult printSwitchArm(CppEmitter &emitter, cf::SwitchOp switchOp,
+                                    OperandRange operands, Block *dest,
+                                    StringRef armKind) {
+  if (!emitter.hasBlockLabel(*dest))
+    return switchOp.emitOpError("unable to find label for ")
+           << armKind << " block";
+
+  raw_indented_ostream &os = emitter.ostream();
+  os.indent();
+  // NOTE(review): the copies are emitted one after another, not as a parallel
+  // copy; if an operand is itself an argument of `dest`, it may be read after
+  // it has already been overwritten.
+  for (auto [operand, argument] : llvm::zip(operands, dest->getArguments()))
+    os << emitter.getOrCreateName(argument) << " = "
+       << emitter.getOrCreateName(operand) << ";\n";
+  os << "goto " << emitter.getOrCreateName(*dest) << ";\n";
+  os.unindent();
+  return success();
+}
+
+/// Emits a `cf.switch` as a `switch` statement. Each arm, `default`
+/// included, is a braced block that assigns the successor operands to the
+/// destination's block arguments and ends in a `goto` to the destination's
+/// label, so no `break` is emitted.
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    cf::SwitchOp switchOp) {
+  raw_indented_ostream &os = emitter.ostream();
+
+  // `case_values` is an optional attribute (a switch may have nothing but a
+  // default destination), so it must not be dereferenced unconditionally.
+  SmallVector<APInt> caseValues;
+  if (auto values = switchOp.getCaseValues())
+    llvm::append_range(caseValues, *values);
+  if (caseValues.size() < switchOp.getCaseDestinations().size())
+    return switchOp.emitOpError("case's value is absent for case block");
+
+  // Print case labels with the flag's signedness. A ui16 value 0xFFFF or an
+  // i1 `true` printed as signed would become `-1`, which a flag holding a
+  // non-negative value can never equal.
+  auto flagType = cast<IntegerType>(switchOp.getFlag().getType());
+  bool isSignedFlag =
+      !flagType.isUnsignedInteger() && flagType.getWidth() != 1;
+
+  os << "\nswitch(" << emitter.getOrCreateName(switchOp.getFlag()) << ") {";
+
+  for (auto [caseIndex, caseBlock] :
+       llvm::enumerate(switchOp.getCaseDestinations())) {
+    SmallString<32> caseLabel;
+    caseValues[caseIndex].toString(caseLabel, /*Radix=*/10, isSignedFlag);
+    os << "\ncase (" << caseLabel << "): {\n";
+    if (failed(printSwitchArm(emitter, switchOp,
+                              switchOp.getCaseOperands(caseIndex), caseBlock,
+                              "case")))
+      return failure();
+    os << "}";
+  }
+
+  os << "\ndefault: {\n";
+  if (failed(printSwitchArm(emitter, switchOp, switchOp.getDefaultOperands(),
+                            switchOp.getDefaultDestination(), "default")))
+    return failure();
+  os << "}\n";
+  os << "}\n";
+
+  return success();
+}
+
 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
                                         StringRef callee) {
   if (failed(emitter.emitAssignPrefix(*callOp)))
@@ -997,8 +1056,8 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
       // When generating code for an emitc.for and emitc.verbatim op, printing a
       // trailing semicolon is handled within the printOperation function.
       bool trailingSemicolon =
-          !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
-               emitc::IfOp, emitc::VerbatimOp>(op);
+          !isa<cf::CondBranchOp, cf::SwitchOp, emitc::DeclareFuncOp,
+               emitc::ForOp, emitc::IfOp, emitc::VerbatimOp>(op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1496,7 +1555,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           // Builtin ops.
           .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
           // CF ops.
-          .Case<cf::BranchOp, cf::CondBranchOp>(
+          .Case<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
               [&](auto op) { return printOperation(*this, op); })
           // EmitC ops.
           .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
diff --git a/mlir/test/Target/Cpp/switch.mlir b/mlir/test/Target/Cpp/switch.mlir
new file mode 100644
index 0000000000000..711a581803140
--- /dev/null
+++ b/mlir/test/Target/Cpp/switch.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+func.func @switch_func(%a: i32, %b: i32, %c: i32) -> () {
+    cf.switch %b : i32, [
+    default: ^bb1(%a : i32),
+    42: ^bb1(%b : i32),
+    43: ^bb2(%c : i32),
+    44: ^bb3(%c : i32)
+    ]
+
+    ^bb1(%x1 : i32) :
+        %y1 = "emitc.add" (%x1, %x1) : (i32, i32) -> i32
+        return
+
+    ^bb2(%x2 : i32) :
+        %y2 = "emitc.sub" (%x2, %x2) : (i32, i32) -> i32
+        return
+
+    ^bb3(%x3 : i32) :
+        %y3 = "emitc.mul" (%x3, %x3) : (i32, i32) -> i32
+        return
+}
+// CHECK-LABEL: void switch_func(
+// CHECK-SAME: int32_t [[V0:[^,]*]], int32_t [[V1:[^,]*]], int32_t [[V2:[^)]*]]) {
+// CHECK: switch([[V1]]) {
+// CHECK-NEXT: case (42): {
+// CHECK-NEXT: [[ARG1:v[0-9]+]] = [[V1]];
+// CHECK-NEXT: goto [[LABEL1:label[0-9]+]];
+// CHECK-NEXT: }
+// CHECK-NEXT: case (43): {
+// CHECK-NEXT: {{v[0-9]+}} = [[V2]];
+// CHECK-NEXT: goto {{label[0-9]+}};
+// CHECK-NEXT: }
+// CHECK-NEXT: case (44): {
+// CHECK-NEXT: {{v[0-9]+}} = [[V2]];
+// CHECK-NEXT: goto {{label[0-9]+}};
+// CHECK-NEXT: }
+// CHECK-NEXT: default: {
+// CHECK-NEXT: [[ARG1]] = [[V0]];
+// CHECK-NEXT: goto [[LABEL1]];
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+// A switch with only a default destination carries no case values.
+func.func @switch_default_only(%a: i32) {
+  cf.switch %a : i32, [
+    default: ^bb1
+  ]
+^bb1:
+  return
+}
+// CHECK-LABEL: void switch_default_only(
+// CHECK-SAME: int32_t [[A:[^)]*]]) {
+// CHECK: switch([[A]]) {
+// CHECK-NEXT: default: {
+// CHECK-NEXT: goto {{label[0-9]+}};
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+// Case labels of an unsigned flag must be printed unsigned (not as -1).
+func.func @switch_unsigned(%u: ui16) {
+  cf.switch %u : ui16, [
+    default: ^bb1,
+    65535: ^bb2
+  ]
+^bb1:
+  return
+^bb2:
+  return
+}
+// CHECK-LABEL: void switch_unsigned(
+// CHECK-SAME: uint16_t [[U:[^)]*]]) {
+// CHECK: switch([[U]]) {
+// CHECK-NEXT: case (65535): {
+// CHECK-NEXT: goto {{label[0-9]+}};
+// CHECK-NEXT: }
+// CHECK-NEXT: default: {
+// CHECK-NEXT: goto {{label[0-9]+}};
+// CHECK-NEXT: }
+// CHECK-NEXT: }



More information about the Mlir-commits mailing list