[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.iterate` operation (PR #88955)

Aart Bik llvmlistbot at llvm.org
Wed Apr 17 14:42:44 PDT 2024


================
@@ -2063,6 +2163,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+    return failure();
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
+
+  iters.append(iterArgs);
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, iters))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+                              Block::BlockArgListType blocksArgs,
+                              LevelSet crdUsedLvls) {
+  if (crdUsedLvls.empty())
+    return;
+
+  p << " at(";
+  for (unsigned i = 0; i < spaceDim; i++) {
+    if (crdUsedLvls[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != spaceDim - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+  p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() != getNumResults()) {
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  }
+  return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                        opResults.size()})) {
+    return emitOpError() << "number mismatch between iter args and results.";
+  }
+
+  for (auto [i, init, iter, yield, ret] :
+       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
+    if (init.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter operand and defined value";
+    if (iter.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter region arg and defined value";
+    if (yield.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th yield value and defined value";
+  }
+
+  return success();
+}
+
+/// IterateOp implemented OpInterfaces' methods.
----------------
aartbik wrote:

why paste tense?

https://github.com/llvm/llvm-project/pull/88955


More information about the Mlir-commits mailing list