[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)

Yinying Li llvmlistbot at llvm.org
Wed Mar 20 10:37:25 PDT 2024


================
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
       isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
+      isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
     return success();
 
   return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                      "reduce, select or foreach");
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lvlLo, lvlHi;
+
+  if (parser.parseInteger(lvlLo))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    if (parser.parseInteger(lvlHi))
+      return failure();
+  } else {
+    lvlHi = lvlLo + 1;
+  }
+  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+  return success();
+}
+
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  if (op.getLoLvl() + 1 == op.getHiLvl())
+    p << op.getLoLvl();
+  else
+    p << op.getLoLvl() << " to " << op.getHiLvl();
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(
+      adaptor.getLoLvl(), adaptor.getHiLvl() - adaptor.getLoLvl());
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+  }
+
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError("parent iterator is only needed iff level low equals 0");
+  }
+
+  if (pIter) {
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses %iters in %spaces
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace)) {
+    return failure();
+  }
+
+  // Parse the optional initial iteration arguments.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  // Region arguments start with the iterator, followed by the optional
+  // user-provided iter_args.
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(regionArgs, operands))
+      return failure();
+
+  // parse ": sparse_tensor.iter_space -> ret"
+  Type iterSpaceTps;
+  if (parser.parseColon() || parser.parseType(iterSpaceTps))
+    return failure();
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(result.types))
+      return failure();
+
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands.
----------------
yinying-lisa-li wrote:

nit: "Resolves" should be "Resolve".

https://github.com/llvm/llvm-project/pull/85958


More information about the Mlir-commits mailing list