[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)

Ingo Müller llvmlistbot at llvm.org
Fri Mar 22 07:19:11 PDT 2024


================
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
       isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
+      isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
     return success();
 
   return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                      "reduce, select or foreach");
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+/// Parses a level range written either as a single level `lo` or as an
+/// explicit range `lo to hi`. Both bounds are returned as index-typed integer
+/// attributes through `lvlLoAttr` and `lvlHiAttr`.
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lo;
+  if (parser.parseInteger(lo))
+    return failure();
+
+  // A lone level `lo` is shorthand for the range `lo to lo + 1`.
+  Level hi = lo + 1;
+  if (succeeded(parser.parseOptionalKeyword("to")) && parser.parseInteger(hi))
+    return failure();
+
+  Type indexTy = parser.getBuilder().getIndexType();
+  lvlLoAttr = IntegerAttr::get(indexTy, lo);
+  lvlHiAttr = IntegerAttr::get(indexTy, hi);
+  return success();
+}
+
+/// Prints the level range of `op` as `lo` when it covers exactly one level
+/// and as `lo to hi` otherwise. The bounds are read from `op`; the `lvlLo` and
+/// `lvlHi` attribute parameters are not consulted.
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  p << op.getLoLvl();
+  // The upper bound is implied for single-level ranges, so it is omitted.
+  if (op.getHiLvl() != op.getLoLvl() + 1)
+    p << " to " << op.getHiLvl();
+}
+
+/// Infers the single result type of the op: an iteration space type built
+/// from the tensor's level types in the range [loLvl, hiLvl).
+///
+/// The level range is validated before slicing because type inference may run
+/// before the op verifier; an empty, inverted, or out-of-rank range would
+/// otherwise trip an assertion in `ArrayRef::slice` instead of producing a
+/// diagnostic.
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+
+  const Level lo = adaptor.getLoLvl();
+  const Level hi = adaptor.getHiLvl();
+  if (lo >= hi || hi > stt.getLvlRank()) {
+    return emitOptionalError(
+        loc, "expected a non-empty level range within the tensor level rank");
+  }
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(lo, hi - lo);
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+/// Verifies the op:
+///   - the level range [loLvl, hiLvl) is non-empty and lies within the
+///     tensor's levels;
+///   - the result iteration space has the level types of that range;
+///   - a parent iterator is given exactly when loLvl is not 0;
+///   - when given, the parent iterator's level types match the tensor levels
+///     immediately preceding loLvl.
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  // Reject ranges that extend past the last tensor level before slicing;
+  // `ArrayRef::slice` asserts on out-of-bounds ranges.
+  if (getHiLvl() > stt.getLvlRank()) {
+    return emitOpError(
+        "expected level high to be no larger than the tensor level rank");
+  }
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+  }
+
+  // A parent iterator must be present for every level range that does not
+  // start at level 0, and must be absent for one that does.
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError(
+        "parent iterator is needed iff level low does not equal 0");
+  }
+
+  if (pIter) {
+    // The parent iterator must cover the `pDim` tensor levels that directly
+    // precede `loLvl`.
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+/// Parses an IterateOp written as
+///   %iterator in %space [iter_args(%arg = %init, ...)]
+///       : space-type [-> result-types] region [attr-dict]
+///
+/// The region's first argument is the iterator, whose type is derived from the
+/// iteration space type; the remaining region arguments are the iter_args.
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses "%iterator in %space".
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace)) {
+    return failure();
+  }
+
+  // Parses the optional initial iteration arguments. The region arguments
+  // start with the iterator, followed by the optional user-provided iter_args.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs && parser.parseAssignmentList(regionArgs, operands))
+    return failure();
+
+  // Parses ": space-type [-> result-types]".
+  Type iterSpaceTps;
+  if (parser.parseColon())
+    return failure();
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (parser.parseType(iterSpaceTps))
+    return failure();
+
+  // Diagnose a wrong type here rather than asserting in a cast further down.
+  auto iterSpaceTp = llvm::dyn_cast<IterSpaceType>(iterSpaceTps);
+  if (!iterSpaceTp) {
+    return parser.emitError(typeLoc)
+           << "expected iteration space type, but got " << iterSpaceTps;
+  }
+
+  if (hasIterArgs && parser.parseArrowTypeList(result.types))
+    return failure();
+
+  // One region argument for the iterator plus one per result.
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands: the iteration space first, then the iter_args.
+  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    // Each iter_arg region argument takes the type of its matching result.
+    for (auto [regionArg, operand, type] :
+         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+      regionArg.type = type;
+      if (parser.resolveOperand(operand, type, result.operands))
+        return failure();
+    }
+  }
+
+  Region *body = result.addRegion();
+  regionArgs.front().type = iterSpaceTp.getIteratorType();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parses the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints a parenthesized, comma-separated initialization list:
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// Each `inner` value is a block argument from `blocksArgs` and each `outer`
+/// value is the matching entry of `initializers`. Nothing is printed, not even
+/// the prefix, when the list is empty.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p,
+                        [&p](const auto &argAndInit) {
+                          const auto &[blockArg, init] = argAndInit;
+                          p << blockArg << " = " << init;
+                        });
+  p << ')';
+}
+
+/// Prints the op in its custom assembly form:
+///   %iterator in %space [iter_args(...)] : space-type [-> (types)] region
+void IterateOp::print(OpAsmPrinter &p) {
+  const bool hasInitArgs = !getInitArgs().empty();
+
+  p << " " << getIterator() << " in " << getIterSpace();
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+  p << " : " << getIterSpace().getType() << " ";
+
+  // Result types are spelled out only when there are loop-carried values.
+  if (hasInitArgs)
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  // Likewise, the block terminator is printed only when there are
+  // loop-carried values.
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/hasInitArgs);
+}
+
+/// Checks that the op has exactly one result per initial loop-carried value.
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() == getNumResults())
+    return success();
+
+  return emitOpError(
+      "mismatch in number of loop-carried values and defined values");
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
----------------
ingomueller-net wrote:

Here and below: I think the [LLVM coding style](https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable) says that this type should be spelled out. (Here, I guess the type is an `ArrayRef<...>`, which is neither "obvious from the context" nor "abstracted away anyways"...)

https://github.com/llvm/llvm-project/pull/85958


More information about the Mlir-commits mailing list