[Mlir-commits] [mlir] [mlir][emitc] Add 'emitc.switch' op to the dialect (PR #102331)

Gil Rapaport llvmlistbot at llvm.org
Thu Aug 8 09:30:32 PDT 2024


================
@@ -1096,6 +1101,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Parses a possibly empty sequence of `case <integer> : <region>` clauses.
+/// A fresh region is appended to `caseRegions` for every `case` keyword seen.
+/// On success, `cases` holds the parsed integers in source order.
+static ParseResult
+parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> values;
+  for (;;) {
+    // The case list ends at the first token that is not `case`.
+    if (failed(parser.parseOptionalKeyword("case")))
+      break;
+
+    caseRegions.push_back(std::make_unique<Region>());
+    Region &caseRegion = *caseRegions.back();
+
+    int64_t caseValue;
+    if (failed(parser.parseInteger(caseValue)))
+      return failure();
+    if (failed(parser.parseColon()))
+      return failure();
+    if (failed(parser.parseRegion(caseRegion, /*arguments=*/{})))
+      return failure();
+
+    values.push_back(caseValue);
+  }
+
+  cases = parser.getBuilder().getDenseI64ArrayAttr(values);
+  return success();
+}
+
+/// Prints each `case <value>: <region>` clause on its own line, pairing the
+/// entries of `cases` with `caseRegions` positionally. Entry block arguments of
+/// the case regions are not printed. `op` is currently unused.
+static void printSwitchCases(OpAsmPrinter &printer, Operation *op,
+                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
+    printer.printNewline();
+    printer << "case " << value << ": ";
+    printer.printRegion(*region, /*printEntryBlockArgs=*/false);
+  }
+}
+
+/// Parses an `emitc.switch` op of the form:
+///   <operand> `:` <type> attr-dict? (`case` <int> `:` <region>)*
+///   `default` `:` <region>
+/// The default region is registered before the case regions.
+ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand operand;
+  Type operandType;
+  if (parser.parseOperand(operand) || parser.parseColon() ||
+      parser.parseType(operandType))
+    return failure();
+
+  // Inherent-attribute diagnostics are anchored at the start of the optional
+  // attribute dictionary.
+  auto attrDictLoc = parser.getCurrentLocation();
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrDictLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes,
+                                 emitAttrError)))
+    return failure();
+
+  DenseI64ArrayAttr caseValues;
+  SmallVector<std::unique_ptr<Region>, 2> caseRegions;
+  if (failed(parseSwitchCases(parser, caseValues, caseRegions)))
+    return failure();
+  result.getOrAddProperties<SwitchOp::Properties>().cases = caseValues;
+
+  auto defaultRegion = std::make_unique<Region>();
+  if (parser.parseKeyword("default") || parser.parseColon() ||
+      parser.parseRegion(*defaultRegion))
+    return failure();
+
+  result.addRegion(std::move(defaultRegion));
+  result.addRegions(caseRegions);
+
+  return parser.resolveOperand(operand, operandType, result.operands);
+}
+
+/// Prints the op in the form accepted by SwitchOp::parse:
+///   <operand> `:` <type> attr-dict? (`case` <int> `:` <region>)*
+///   `default` `:` <region>
+/// The operand type and the colon after `default` are required by the parser,
+/// so both must be emitted for the printed IR to round-trip.
+void SwitchOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getArg() << " : " << getArg().getType();
+  // `cases` is rendered by the case list itself, so keep it out of attr-dict.
+  printer.printOptionalAttrDict((*this)->getAttrs(),
+                                /*elidedAttrs=*/{"cases"});
+  printer << ' ';
+  printSwitchCases(printer, *this, getCasesAttr(), getCaseRegions());
+  printer.printNewline();
+  printer << "default: ";
+  printer.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/true,
+                      /*printBlockTerminators=*/true);
+}
+
+/// Checks that the entry block of `region` ends with an `emitc.yield` that has
+/// no operands. `name` identifies the region in diagnostics, e.g.
+/// "default region" or "case region #N".
+static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
+                                  const Twine &name) {
+  // Guard the front()/back() accesses below: a region with no blocks, or an
+  // entry block with no operations, has no terminator to inspect.
+  if (region.empty() || region.front().empty())
+    return op.emitOpError("expected region to end with emitc.yield, but ")
+           << name << " is empty";
+
+  Operation &terminator = region.front().back();
+  auto yield = dyn_cast<emitc::YieldOp>(terminator);
+  if (!yield)
+    return op.emitOpError("expected region to end with emitc.yield, but got ")
+           << terminator.getName();
+
+  if (yield.getNumOperands() != 0) {
+    return (op.emitOpError("expected each region to return ")
+            << "0 values, but " << name << " returns "
+            << yield.getNumOperands())
+               .attachNote(yield.getLoc())
+           << "see yield operation here";
+  }
+  return success();
+}
+
+/// Verifies the switch. It checks, in order:
+/// - the operand type is accepted by isSwitchOperandType;
+/// - there is one case region per case value;
+/// - case values are unique;
+/// - every region (default first, then cases in order) ends with an
+///   operand-free `emitc.yield`.
+LogicalResult emitc::SwitchOp::verify() {
+  Type argType = getArg().getType();
+  if (!isSwitchOperandType(argType))
+    return emitOpError("unsupported type ") << argType;
+
+  size_t numCaseValues = getCases().size();
+  size_t numCaseRegions = getCaseRegions().size();
+  if (numCaseValues != numCaseRegions)
+    return emitOpError("has ") << numCaseRegions << " case regions but "
+                               << numCaseValues << " case values";
+
+  // The first repeated value encountered is the one reported.
+  DenseSet<int64_t> seenValues;
+  for (int64_t caseValue : getCases()) {
+    if (!seenValues.insert(caseValue).second)
+      return emitOpError("has duplicate case value: ") << caseValue;
+  }
+
+  if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
+    return failure();
+
+  auto caseRegions = getCaseRegions();
+  for (size_t i = 0, e = caseRegions.size(); i < e; ++i) {
+    if (failed(verifyRegion(*this, caseRegions[i], "case region #" + Twine(i))))
+      return failure();
+  }
+
+  return success();
+}
+
+/// Returns the number of case values attached to this switch.
+unsigned emitc::SwitchOp::getNumCases() {
+  return static_cast<unsigned>(getCases().size());
+}
+
+/// Returns the first (entry) block of the default region.
+Block &emitc::SwitchOp::getDefaultBlock() {
+  Region &defaultRegion = getDefaultRegion();
+  return defaultRegion.front();
+}
+
+/// Returns the first (entry) block of the case region at position `idx`;
+/// `idx` must be smaller than getNumCases().
+Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
+  assert(idx < getNumCases() && "case index out-of-bounds");
+  Region &caseRegion = getCaseRegions()[idx];
+  return caseRegion.front();
+}
+
+/// Reports where control may flow from `point`.
+/// - Entering the op from the parent may reach any of its regions.
+/// - Leaving any region transfers control back to the parent op rather than
+///   into a sibling region. This is consistent with getEntrySuccessorRegions,
+///   which selects exactly one region for a known selector.
+void SwitchOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
+  if (!point.isParent()) {
+    // The op has no results, so the parent successor forwards no values.
+    successors.push_back(RegionSuccessor());
+    return;
+  }
+
+  llvm::copy(getRegions(), std::back_inserter(successors));
+}
+
+/// Computes the regions entered from the parent op given the (possibly
+/// constant) operand attributes in `operands`.
+/// - A known integer selector selects the matching case region, or the default
+///   region when no case value matches.
+/// - Without a known selector, every region is a candidate.
+void SwitchOp::getEntrySuccessorRegions(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  FoldAdaptor adaptor(operands, *this);
+  auto selector = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
+
+  // Unknown selector: any region may be entered.
+  if (!selector) {
+    llvm::copy(getRegions(), std::back_inserter(successors));
+    return;
+  }
+
+  // NOTE(review): IntegerAttr::getInt() is only valid for signless or index
+  // typed attributes. Confirm isSwitchOperandType() excludes signed/unsigned
+  // integer types, or compare through getValue() instead.
+  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
+    if (caseValue != selector.getInt())
+      continue;
+    successors.emplace_back(&caseRegion);
+    return;
+  }
+
+  // No case value matched the selector.
+  successors.emplace_back(&getDefaultRegion());
+}
+
+void SwitchOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
+  if (!operandValue) {
+    // All regions are invoked at most once.
+    bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
+    return;
+  }
+
+  unsigned liveIndex = getNumRegions() - 1;
+  const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
+
+  liveIndex = iteratorToInt != getCases().end()
+                  ? std::distance(getCases().begin(), iteratorToInt)
+                  : liveIndex;
+
+  for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
+       ++regIndex)
+    bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
+
+  return;
----------------
aniragil wrote:

```suggestion
```

https://github.com/llvm/llvm-project/pull/102331


More information about the Mlir-commits mailing list