[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.iterate` operation (PR #88955)
Peiming Liu
llvmlistbot at llvm.org
Tue May 21 11:29:12 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/88955
>From 221c7a85de33645cc88c980447bf4678000e4748 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 12 Apr 2024 22:03:06 +0000
Subject: [PATCH 1/5] [mlir][sparse] introduce `sparse_tensor.iterate`
operation
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 38 +++
.../SparseTensor/IR/SparseTensorAttrDefs.td | 15 ++
.../SparseTensor/IR/SparseTensorOps.td | 101 ++++++-
.../SparseTensor/IR/SparseTensorDialect.cpp | 255 ++++++++++++++++++
mlir/test/Dialect/SparseTensor/invalid.mlir | 57 ++++
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 28 ++
.../SparseTensor/sparse_itertion_licm.mlir | 26 ++
7 files changed, 519 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3cf81d2e58f21..239118a6575d8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,9 +17,13 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/bit.h"
+
//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
@@ -54,6 +58,40 @@ struct COOSegment {
}
};
+/// A simple wrapper to encode a bitset of defined (at most 64) levels.
+/// Bit `i` of the underlying 64-bit word is set iff level `i` is in the set.
+class LevelSet {
+  uint64_t bits = 0;
+
+public:
+  LevelSet() = default;
+  explicit LevelSet(uint64_t bits) : bits(bits) {}
+  operator uint64_t() const { return bits; }
+
+  /// Adds level `i` (which must be less than 64) to the set.
+  LevelSet &set(unsigned i) {
+    assert(i < 64);
+    // Shift a 64-bit one: a plain `1 << i` is evaluated in `int` and is
+    // undefined for i >= 32, which would lose the upper half of the set.
+    bits |= static_cast<uint64_t>(1) << i;
+    return *this;
+  }
+
+  /// Merges all levels of the other set into this one.
+  LevelSet &operator|=(LevelSet lhs) {
+    bits |= static_cast<uint64_t>(lhs);
+    return *this;
+  }
+
+  /// Shifts every level in the set up by `offset` positions.
+  LevelSet &lshift(unsigned offset) {
+    bits = bits << offset;
+    return *this;
+  }
+
+  /// Returns true iff level `i` (which must be less than 64) is in the set.
+  bool operator[](unsigned i) const {
+    assert(i < 64);
+    // Same 64-bit mask as in `set`, so that levels 32..63 can be queried.
+    return (bits & (static_cast<uint64_t>(1) << i)) != 0;
+  }
+
+  /// Number of levels in the set.
+  unsigned count() const { return llvm::popcount(bits); }
+  /// True iff no level is in the set.
+  bool empty() const { return bits == 0; }
+};
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 53dd8e39438cc..c26193af09e81 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+ TypedAttrBase<
+ I64, "IntegerAttr",
+ And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+ CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+ "LevelSet attribute"> {
+ let returnType = [{::mlir::sparse_tensor::LevelSet}];
+ let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4e4441c640ed9..0ef03a5afa733 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
- "ForeachOp"]>]> {
+ "ForeachOp", "IterateOp"]>]> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
let hasVerifier = 1;
}
+def IterateOp : SparseTensor_Op<"iterate",
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+ "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+ let summary = "Iterate over a sparse iteration space";
+ let description = [{
+ The `sparse_tensor.iterate` operations represents a loop over the
+ provided iteration space extracted from a specific sparse tensor.
+ The operation defines an SSA value for a sparse iterator that points
+ to the current stored element in the sparse tensor and SSA values
+ for coordinates of the stored element. The coordinates are always
+ converted to `index` type regardless of the underlying sparse tensor
+ storage. When coordinates are not used, the SSA values can be skipped
+ by `_` symbols, which usually leads to simpler generated code after
+ sparsification. For example:
+
+ ```mlir
+ // The coordinate for level 0 is not used when iterating over a 2-D
+ // iteration space.
+ sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+ ```
+
+ `sparse_tensor.iterate` can also operate on loop-carried variables
+ and returns the final values after loop termination.
+ The initial values of the variables are passed as additional SSA operands
+ to the iterator SSA value and used coordinate SSA values mentioned
+ above. The operation region has an argument for the iterator, variadic
+ arguments for specified (used) coordinates, followed by one argument
+ for each loop-carried variable, representing the value of the variable
+ at the current iteration.
+ The body region must contain exactly one block that terminates with
+ `sparse_tensor.yield`.
+
+ `sparse_tensor.iterate` results hold the final values after the last
+ iteration. If the `sparse_tensor.iterate` defines any values, a yield
+ must be explicitly present.
+ The number and types of the `sparse_tensor.iterate` results must match
+ the initial values in the iter_args binding and the yield operands.
+
+
+ A nested `sparse_tensor.iterate` example that prints all the coordinates
+ stored in the sparse input:
+
+ ```mlir
+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
+ // Iterates over the first level of %sp
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
+ // Iterates over the second level of %sp
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
+ vector.print %crd0 : index
+ vector.print %crd1 : index
+ }
+ }
+ }
+
+ ```
+ }];
+
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
+ Variadic<AnyType>:$initArgs,
+ LevelSetAttr:$crdUsedLvls);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$region);
+
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getIterSpace().getType().getSpaceDim();
+ }
+ BlockArgument getIterator() {
+ return getRegion().getArguments().front();
+ }
+ Block::BlockArgListType getCrds() {
+ // The first block argument is iterator, the remaining arguments are
+ // referenced coordinates.
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ }
+ unsigned getNumRegionIterArgs() {
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4adb1c19096a2..0a9254573538d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2130,6 +2130,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
printLevelRange(p, lo, hi);
}
+static ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &iterators,
+ SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+ SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+ SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
+ // Parses "%iters, ... in %spaces, ..."
+ if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+ parser.parseOperandList(spaces))
+ return failure();
+
+ if (iterators.size() != spaces.size())
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of sparse iterators and sparse spaces");
+
+ // Parses the optional "at(%crd0, _, ...)" list; "_" marks an unused level.
+ LevelSet crdUsedLvlSet;
+ bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+ unsigned lvlCrdCnt = 0;
+ if (hasUsedCrds) {
+ ParseResult crdList = parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+ if (parser.parseOptionalKeyword("_")) {
+ if (parser.parseArgument(iterArgs.emplace_back()))
+ return failure();
+ // Always use IndexType for the coordinate.
+ crdUsedLvlSet.set(lvlCrdCnt);
+ iterArgs.back().type = parser.getBuilder().getIndexType();
+ }
+ lvlCrdCnt += 1;
+ return success();
+ });
+ if (failed(crdList)) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expecting SSA value or \"_\" for level coordinates");
+ }
+ }
+ // Records the used-coordinate bitset as the "crdUsedLvls" attribute.
+ state.addAttribute("crdUsedLvls",
+ parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+ // Parses the optional "iter_args(%arg = %init, ...)" list.
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs)
+ if (parser.parseAssignmentList(iterArgs, initArgs))
+ return failure();
+
+ SmallVector<Type> iterSpaceTps;
+ // Parses ": sparse_tensor.iter_space -> ret"
+ if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+ return failure();
+ if (iterSpaceTps.size() != spaces.size())
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space operands "
+ "and iteration space types");
+
+ for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+ IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+ if (!spaceTp)
+ return parser.emitError(parser.getNameLoc(),
+ "expected sparse_tensor.iter_space type for "
+ "iteration space operands");
+ if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space dimension "
+ "and specified coordinates");
+ it.type = spaceTp.getIteratorType();
+ }
+
+ if (hasIterArgs)
+ if (parser.parseArrowTypeList(state.types))
+ return failure();
+
+ // Resolves input operands.
+ if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+ state.operands))
+ return failure();
+
+ if (hasIterArgs) {
+ unsigned numCrds = crdUsedLvlSet.count();
+ // Strips off the leading args that are used for coordinates.
+ MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+ if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of iteration arguments and return values");
+ }
+
+ for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+ it.type = tp;
+ if (parser.resolveOperand(init, tp, state.operands))
+ return failure();
+ }
+ }
+ return success();
+}
+
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
@@ -2166,6 +2266,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+/// Parses a `sparse_tensor.iterate` operation: the loop header (delegated to
+/// `parseSparseSpaceLoop`), the body region, and an optional attribute
+/// dictionary. Fails if the header names anything other than exactly one
+/// iterator/iteration space pair.
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+    return failure();
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
+
+  // The region arguments are the iterator followed by everything collected in
+  // `iterArgs` (used coordinates first, then loop-carried values).
+  iters.append(iterArgs);
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, iters))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the initialization list in the form of
+/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ p << prefix << '(';
+ llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+ p << std::get<0>(it) << " = " << std::get<1>(it);
+ });
+ p << ")";
+}
+
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+ Block::BlockArgListType blocksArgs,
+ LevelSet crdUsedLvls) {
+ if (crdUsedLvls.empty())
+ return;
+
+ p << " at(";
+ for (unsigned i = 0; i < spaceDim; i++) {
+ if (crdUsedLvls[i]) {
+ p << blocksArgs.front();
+ blocksArgs = blocksArgs.drop_front();
+ } else {
+ p << "_";
+ }
+ if (i != spaceDim - 1)
+ p << ", ";
+ }
+ assert(blocksArgs.empty());
+ p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+ p << " " << getIterator() << " in " << getIterSpace();
+ printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+ printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+ p << " : " << getIterSpace().getType() << " ";
+ if (!getInitArgs().empty())
+ p << "-> (" << getInitArgs().getTypes() << ") ";
+
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+ if (getInitArgs().size() != getNumResults()) {
+ return emitOpError(
+ "mismatch in number of loop-carried values and defined values");
+ }
+ return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+ if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+ return emitOpError("mismatch in iterator and iteration space type");
+ if (getNumRegionIterArgs() != getNumResults())
+ return emitOpError(
+ "mismatch in number of basic block args and defined values");
+
+ auto initArgs = getInitArgs();
+ auto iterArgs = getRegionIterArgs();
+ auto yieldVals = getYieldedValues();
+ auto opResults = getResults();
+ if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+ opResults.size()})) {
+ return emitOpError() << "number mismatch between iter args and results.";
+ }
+
+ for (auto [i, init, iter, yield, ret] :
+ llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
+ if (init.getType() != ret.getType())
+ return emitOpError() << "types mismatch between " << i
+ << "th iter operand and defined value";
+ if (iter.getType() != ret.getType())
+ return emitOpError() << "types mismatch between " << i
+ << "th iter region arg and defined value";
+ if (yield.getType() != ret.getType())
+ return emitOpError() << "types mismatch between " << i
+ << "th yield value and defined value";
+ }
+
+ return success();
+}
+
+/// IterateOp implemented OpInterfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+ return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+ return getRegion().getArguments().take_back(getNumRegionIterArgs());
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+ return cast<sparse_tensor::YieldOp>(
+ getRegion().getBlocks().front().getTerminator())
+ .getResultsMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ return getInitArgs();
+}
+
+/// Reports the possible control-flow successors of the loop: the body region
+/// (receiving the region iter args) and the parent operation (receiving the
+/// op results). Both are appended to `regions` regardless of `point`.
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3fa696e1600a9..b13024cd4ed99 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1094,3 +1094,60 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return
}
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
+ %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
+ sparse_tensor.yield %si : index
+ }
+ return %r1 : index
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// expected-note@+1 {{prior use here}}
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
+ sparse_tensor.yield %outer : f32
+ }
+ return %r1 : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
+ %y = arith.constant 1.0 : f32
+ sparse_tensor.yield %y : f32
+ }
+ return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d34071279e512..e9a898f16b41d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -763,3 +763,31 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
}
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_iterate(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME: %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: index) -> index {
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK: sparse_tensor.yield %[[VAL_7]] : index
+// CHECK: }
+// CHECK: return %[[VAL_4]] : index
+// CHECK: }
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+ sparse_tensor.yield %outer : index
+ }
+ return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
new file mode 100644
index 0000000000000..e7158d04b37fe
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : dense,
+ j : compressed
+ )
+}>
+
+// Make sure that pure instructions are hoisted outside the loop.
+//
+// CHECK: sparse_tensor.values
+// CHECK: sparse_tensor.positions
+// CHECK: sparse_tensor.coordinate
+// CHECK: sparse_tensor.iterate
+func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+ sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
+ %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
+ %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+ %2 = sparse_tensor.coordinates %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+ "test.op"(%0, %1, %2) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
+ }
+
+ return
+}
>From c6368bc924503841281ef539adbb084ac7cff530 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 17 Apr 2024 21:50:31 +0000
Subject: [PATCH 2/5] address comments
---
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h | 4 ++--
.../mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td | 2 +-
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 4 ++--
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
4 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 239118a6575d8..338adf17be891 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -58,7 +58,7 @@ struct COOSegment {
}
};
-/// A simple wrapper to encode a bitset of defined (at most 64) levels.
+/// A simple wrapper to encode a bitset of defined (at most 64) levels.
class LevelSet {
uint64_t bits = 0;
@@ -69,7 +69,7 @@ class LevelSet {
LevelSet &set(unsigned i) {
assert(i < 64);
- bits |= 1 << i;
+ bits |= static_cast<uint64_t>(0x01u) << i;
return *this;
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index c26193af09e81..69b212cce4ceb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -20,7 +20,7 @@ class SparseTensor_Attr<string name,
: AttrDef<SparseTensor_Dialect, name, traits>;
//===----------------------------------------------------------------------===//
-// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// A simple bitset attribute wrapped around a single int64_t to encode a set of
// sparse tensor levels.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0ef03a5afa733..735a35c139834 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1526,8 +1526,8 @@ def IterateOp : SparseTensor_Op<"iterate",
let summary = "Iterate over a sparse iteration space";
let description = [{
- The `sparse_tensor.iterate` operations represents a loop over the
- provided iteration space extracted from a specific sparse tensor.
+ The `sparse_tensor.iterate` operation represents a loop (nest) over
+ the provided iteration space extracted from a specific sparse tensor.
The operation defines an SSA value for a sparse iterator that points
to the current stored element in the sparse tensor and SSA values
for coordinates of the stored element. The coordinates are always
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0a9254573538d..b5b92f9bfa0d2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2385,7 +2385,7 @@ LogicalResult IterateOp::verifyRegions() {
return success();
}
-/// IterateOp implemented OpInterfaces' methods.
+/// OpInterfaces' methods implemented by IterateOp.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
>From d590b70e925c4da5d4097356e83ac73bc52d9a51 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 May 2024 21:00:26 +0000
Subject: [PATCH 3/5] use coord instead of crd
---
.../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 735a35c139834..2cc367931ce45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1568,15 +1568,15 @@ def IterateOp : SparseTensor_Op<"iterate",
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
// Iterates over the first level of %sp
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
- %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
// Iterates over the second level of %sp
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
- %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (%coord1)
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
- vector.print %crd0 : index
- vector.print %crd1 : index
+ vector.print %coord0 : index
+ vector.print %coord1 : index
}
}
}
>From 777b98d2d9d6c4f9c8123593eed2c726d94cf0e2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 May 2024 21:25:46 +0000
Subject: [PATCH 4/5] make iter_space type explicit
---
.../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 12 ++++++------
mlir/test/Dialect/SparseTensor/invalid.mlir | 12 ++++++++----
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 5 +++--
.../Dialect/SparseTensor/sparse_itertion_licm.mlir | 1 +
4 files changed, 18 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2cc367931ce45..6b337d58cf297 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1510,7 +1510,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
LevelAttr:$loLvl, LevelAttr:$hiLvl);
let results = (outs AnySparseIterSpace:$resultSpace);
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
- " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
let hasVerifier = 1;
}
@@ -1543,8 +1543,8 @@ def IterateOp : SparseTensor_Op<"iterate",
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
```
- `sparse_tensor.iterate` can also operate on loop-carried variables
- and returns the final values after loop termination.
+ `sparse_tensor.iterate` can also operate on loop-carried variables.
+ It returns the final values after loop termination.
The initial values of the variables are passed as additional SSA operands
to the iterator SSA value and used coordinate SSA values mentioned
above. The operation region has an argument for the iterator, variadic
@@ -1554,9 +1554,9 @@ def IterateOp : SparseTensor_Op<"iterate",
The body region must contain exactly one block that terminates with
`sparse_tensor.yield`.
- `sparse_tensor.iterate` results hold the final values after the last
- iteration. If the `sparse_tensor.iterate` defines any values, a yield
- must be explicitly present.
+ The results of a `sparse_tensor.iterate` hold the final values after
+ the last iteration. If the `sparse_tensor.iterate` defines any values,
+ a yield must be explicitly present.
The number and types of the `sparse_tensor.iterate` results must match
the initial values in the iter_args binding and the yield operands.
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index b13024cd4ed99..eb0dc01be25b9 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1025,6 +1025,7 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+ -> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
return
}
@@ -1040,6 +1041,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1054,7 +1056,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1077,6 +1079,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1092,6 +1095,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 2>
return
}
@@ -1106,7 +1110,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
}>
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
%r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
sparse_tensor.yield %si : index
@@ -1125,7 +1129,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
// expected-note@+1 {{prior use here}}
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
sparse_tensor.yield %outer : f32
@@ -1143,7 +1147,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
}>
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
%y = arith.constant 1.0 : f32
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e9a898f16b41d..bce0b41a99828 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -758,9 +758,10 @@ func.func @sparse_has_runtime() -> i1 {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
-> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
// Extracting the iteration space for the first level needs no parent iterator.
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// Extracting the iteration space for the second level needs a parent iterator.
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
}
@@ -785,7 +786,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
// CHECK: return %[[VAL_4]] : index
// CHECK: }
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
sparse_tensor.yield %outer : index
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
index e7158d04b37fe..f70fab3b7251d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -15,6 +15,7 @@
// CHECK: sparse_tensor.iterate
func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+ -> !sparse_tensor.iter_space<#CSR, lvls = 0>
sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
%0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
%1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
>From 9c4a33f4bed97f2147306839c19abcaa4483d53e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 May 2024 17:17:53 +0000
Subject: [PATCH 5/5] update example code
---
.../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 12 ++++++++----
.../Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
2 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6b337d58cf297..3b223c2cf3863 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1489,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
// Extracts a 1-D iteration space from a COO tensor at level 1.
%space = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
```
}];
@@ -1501,16 +1502,17 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
return getHiLvl() - getLoLvl();
}
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
- return getResultSpace().getType().getLvlTypes();
+ return getExtractedSpace().getType().getLvlTypes();
}
}];
let arguments = (ins AnySparseTensor:$tensor,
Optional<AnySparseIterator>:$parentIter,
LevelAttr:$loLvl, LevelAttr:$hiLvl);
- let results = (outs AnySparseIterSpace:$resultSpace);
+ let results = (outs AnySparseIterSpace:$extractedSpace);
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
- " attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
+ "`->` qualified(type($extractedSpace))";
let hasVerifier = 1;
}
@@ -1567,12 +1569,14 @@ def IterateOp : SparseTensor_Op<"iterate",
```mlir
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
// Iterates over the first level of %sp
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
%r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
// Iterates over the second level of %sp
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
%r2 = sparse_tensor.iterate %it2 in %l2 at (%coord1)
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
vector.print %coord0 : index
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b5b92f9bfa0d2..d2de6bc0dbb85 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2253,7 +2253,7 @@ LogicalResult ExtractIterSpaceOp::verify() {
}
if (pIter) {
- IterSpaceType spaceTp = getResultSpace().getType();
+ IterSpaceType spaceTp = getExtractedSpace().getType();
if (pIter.getType().getEncoding() != spaceTp.getEncoding())
return emitOpError(
"mismatch in parent iterator encoding and iteration space encoding.");
More information about the Mlir-commits
mailing list