[Mlir-commits] [mlir] e276cf0 - [mlir][sparse] introduce `sparse_tensor.iterate` operation (#88955)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 10 10:20:28 PDT 2024
Author: Peiming Liu
Date: 2024-06-10T10:20:24-07:00
New Revision: e276cf0831bfc71ef42eab78368aa487037ac27f
URL: https://github.com/llvm/llvm-project/commit/e276cf0831bfc71ef42eab78368aa487037ac27f
DIFF: https://github.com/llvm/llvm-project/commit/e276cf0831bfc71ef42eab78368aa487037ac27f.diff
LOG: [mlir][sparse] introduce `sparse_tensor.iterate` operation (#88955)
A `sparse_tensor.iterate` iterates over a sparse iteration space
extracted from `sparse_tensor.extract_iteration_space` operation
introduced in https://github.com/llvm/llvm-project/pull/88554.
Added:
mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3cf81d2e58f21..04a6386a199de 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,9 +17,13 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/bit.h"
+
//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
@@ -54,6 +58,42 @@ struct COOSegment {
}
};
+/// A simple wrapper to encode a bitset of (at most 64) levels, currently used
+/// by `sparse_tensor.iterate` operation for the set of levels on which the
+/// coordinates should be loaded. Bit `i` of the underlying 64-bit word is set
+/// iff level `i` is a member of the set.
+class LevelSet {
+ uint64_t bits = 0;
+
+public:
+ LevelSet() = default;
+ explicit LevelSet(uint64_t bits) : bits(bits) {}
+ /// Returns the raw 64-bit encoding of the set.
+ operator uint64_t() const { return bits; }
+
+ /// Adds level `i` to the set; `i` must be less than 64.
+ LevelSet &set(unsigned i) {
+ assert(i < 64);
+ bits |= static_cast<uint64_t>(0x01u) << i;
+ return *this;
+ }
+
+ /// In-place union with `lhs`.
+ LevelSet &operator|=(LevelSet lhs) {
+ bits |= static_cast<uint64_t>(lhs);
+ return *this;
+ }
+
+ /// Shifts every member up by `offset` levels; members shifted past bit 63
+ /// are dropped. `offset` is not checked here and shifting a 64-bit value by
+ /// 64 or more is undefined, so callers must pass an offset below 64.
+ LevelSet &lshift(unsigned offset) {
+ bits = bits << offset;
+ return *this;
+ }
+
+ /// Returns true iff level `i` is in the set; `i` must be less than 64.
+ bool operator[](unsigned i) const {
+ assert(i < 64);
+ // Widen the mask to 64 bits before shifting, mirroring `set()`. A plain
+ // `1 << i` is an `int` shift, which is undefined once `i` reaches the
+ // width of `int`, so the upper levels could not be queried reliably.
+ return (bits & (static_cast<uint64_t>(0x01u) << i)) != 0;
+ }
+
+ /// Number of levels in the set.
+ unsigned count() const { return llvm::popcount(bits); }
+ /// Returns true iff no level is in the set.
+ bool empty() const { return bits == 0; }
+};
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 53dd8e39438cc..69b212cce4ceb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped around a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+ TypedAttrBase<
+ I64, "IntegerAttr",
+ And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+ CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+ "LevelSet attribute"> {
+ let returnType = [{::mlir::sparse_tensor::LevelSet}];
+ let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4e4441c640ed9..5ae6f9f3443f8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
- "ForeachOp"]>]> {
+ "ForeachOp", "IterateOp"]>]> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1476,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
iteration space.
- The type of returned the value is automatically inferred to
+ The type of the returned value must be
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
The returned iteration space can then be iterated over by
`sparse_tensor.iterate` operations to visit every stored element
@@ -1487,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
// Extracts a 1-D iteration space from a COO tensor at level 1.
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ ->!sparse_tensor.iter_space<#COO, lvls = 1>
```
}];
@@ -1499,20 +1502,120 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
return getHiLvl() - getLoLvl();
}
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
- return getResultSpace().getType().getLvlTypes();
+ return getExtractedSpace().getType().getLvlTypes();
}
}];
let arguments = (ins AnySparseTensor:$tensor,
Optional<AnySparseIterator>:$parentIter,
LevelAttr:$loLvl, LevelAttr:$hiLvl);
- let results = (outs AnySparseIterSpace:$resultSpace);
+ let results = (outs AnySparseIterSpace:$extractedSpace);
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
- " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
+ "`->` qualified(type($extractedSpace))";
let hasVerifier = 1;
}
+def IterateOp : SparseTensor_Op<"iterate",
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+ "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+ let summary = "Iterates over a sparse iteration space";
+ let description = [{
+ The `sparse_tensor.iterate` operation represents a loop (nest) over
+ the provided iteration space extracted from a specific sparse tensor.
+ The operation defines an SSA value for a sparse iterator that points
+ to the current stored element in the sparse tensor and SSA values
+ for coordinates of the stored element. The coordinates are always
+ converted to `index` type regardless of the underlying sparse tensor
+ storage. When coordinates are not used, the SSA values can be skipped
+ by `_` symbols, which usually leads to simpler generated code after
+ sparsification. For example:
+
+ ```mlir
+ // The coordinate for level 0 is not used when iterating over a 2-D
+ // iteration space.
+ sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+ ```
+
+ `sparse_tensor.iterate` can also operate on loop-carried variables.
+ It returns the final values after loop termination.
+ The initial values of the variables are passed as additional SSA operands
+ to the iterator SSA value and used coordinate SSA values mentioned
+ above. The operation region has an argument for the iterator, variadic
+ arguments for specified (used) coordinates and followed by one argument
+ for each loop-carried variable, representing the value of the variable
+ at the current iteration.
+ The body region must contain exactly one block that terminates with
+ `sparse_tensor.yield`.
+
+ The results of a `sparse_tensor.iterate` hold the final values after
+ the last iteration. If the `sparse_tensor.iterate` defines any values,
+ a yield must be explicitly present.
+ The number and types of the `sparse_tensor.iterate` results must match
+ the initial values in the iter_args binding and the yield operands.
+
+
+ A nested `sparse_tensor.iterate` example that prints all the coordinates
+ stored in the sparse input:
+
+ ```mlir
+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
+ // Iterates over the first level of %sp
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
+ // Iterates over the second level of %sp
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (%coord1)
+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
+ vector.print %coord0 : index
+ vector.print %coord1 : index
+ }
+ }
+ }
+
+ ```
+ }];
+
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
+ Variadic<AnyType>:$initArgs,
+ LevelSetAttr:$crdUsedLvls);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$region);
+
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getIterSpace().getType().getSpaceDim();
+ }
+ BlockArgument getIterator() {
+ return getRegion().getArguments().front();
+ }
+ Block::BlockArgListType getCrds() {
+ // The first block argument is iterator, the remaining arguments are
+ // referenced coordinates.
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ }
+ unsigned getNumRegionIterArgs() {
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4adb1c19096a2..232d25d718c65 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2130,6 +2130,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
printLevelRange(p, lo, hi);
}
+/// Parses the textual header of a loop over sparse iteration spaces:
+///   %iters, ... in %spaces, ... [at(%crd | _, ...)]
+///   [iter_args(%arg = %init, ...)] : space-types [-> result-types]
+/// On success `iterators` holds the typed iterator arguments, `iterArgs`
+/// holds the used coordinate arguments (typed `index`) followed by the
+/// loop-carried arguments, and `state` has received the `crdUsedLvls`
+/// attribute, the resolved operands and the result types.
+static ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &iterators,
+ SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+ // All diagnostics emitted here are anchored at the operation name.
+ auto nameLoc = parser.getNameLoc();
+ SmallVector<OpAsmParser::UnresolvedOperand> spaceOperands;
+ SmallVector<OpAsmParser::UnresolvedOperand> initOperands;
+
+ // "%iters, ... in %spaces, ..."
+ if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+ parser.parseOperandList(spaceOperands))
+ return failure();
+
+ if (iterators.size() != spaceOperands.size())
+ return parser.emitError(
+ nameLoc, "mismatch in number of sparse iterators and sparse spaces");
+
+ // Optional "at(%crd0, _, ...)": one entry per level, `_` marks a level
+ // whose coordinate is not used.
+ LevelSet usedLvls;
+ unsigned numLvlCrds = 0;
+ const bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+ if (hasUsedCrds) {
+ auto parseOneCrd = [&]() -> ParseResult {
+ if (failed(parser.parseOptionalKeyword("_"))) {
+ // Not a placeholder, so an SSA argument is required here.
+ OpAsmParser::Argument &crd = iterArgs.emplace_back();
+ if (parser.parseArgument(crd))
+ return failure();
+ // Always use IndexType for the coordinate.
+ usedLvls.set(numLvlCrds);
+ crd.type = parser.getBuilder().getIndexType();
+ }
+ ++numLvlCrds;
+ return success();
+ };
+ if (failed(parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
+ parseOneCrd)))
+ return parser.emitError(
+ nameLoc, "expecting SSA value or \"_\" for level coordinates");
+ }
+ // Record which levels have their coordinates bound (empty set if no "at").
+ state.addAttribute("crdUsedLvls",
+ parser.getBuilder().getI64IntegerAttr(usedLvls));
+
+ // Optional "iter_args(%arg = %init, ...)"; the region arguments are
+ // appended after the coordinate arguments collected above.
+ const bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs && parser.parseAssignmentList(iterArgs, initOperands))
+ return failure();
+
+ // ": iter_space-types"
+ SmallVector<Type> spaceTypes;
+ if (parser.parseColon() || parser.parseTypeList(spaceTypes))
+ return failure();
+ if (spaceTypes.size() != spaceOperands.size())
+ return parser.emitError(nameLoc,
+ "mismatch in number of iteration space operands "
+ "and iteration space types");
+
+ // Each iterator takes the iterator type of its iteration space. The sizes
+ // of `iterators` and `spaceTypes` were checked to agree above.
+ for (unsigned i = 0, e = spaceTypes.size(); i < e; ++i) {
+ auto spaceTp = llvm::dyn_cast<IterSpaceType>(spaceTypes[i]);
+ if (!spaceTp)
+ return parser.emitError(nameLoc,
+ "expected sparse_tensor.iter_space type for "
+ "iteration space operands");
+ if (hasUsedCrds && spaceTp.getSpaceDim() != numLvlCrds)
+ return parser.emitError(nameLoc,
+ "mismatch in number of iteration space dimension "
+ "and specified coordinates");
+ iterators[i].type = spaceTp.getIteratorType();
+ }
+
+ // "-> result-types" is only present together with iter_args.
+ if (hasIterArgs && parser.parseArrowTypeList(state.types))
+ return failure();
+
+ // Resolves input operands.
+ if (parser.resolveOperands(spaceOperands, spaceTypes, nameLoc,
+ state.operands))
+ return failure();
+
+ if (!hasIterArgs)
+ return success();
+
+ // Strip off the leading arguments that bind coordinates; what remains are
+ // the loop-carried arguments.
+ MutableArrayRef carried =
+ MutableArrayRef(iterArgs).drop_front(usedLvls.count());
+ if (carried.size() != initOperands.size() ||
+ carried.size() != state.types.size())
+ return parser.emitError(
+ nameLoc, "mismatch in number of iteration arguments and return values");
+
+ // A loop-carried argument, its initial operand and the matching result all
+ // share one type.
+ for (auto [arg, init, resTp] :
+ llvm::zip_equal(carried, initOperands, state.types)) {
+ arg.type = resTp;
+ if (parser.resolveOperand(init, resTp, state.operands))
+ return failure();
+ }
+ return success();
+}
+
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
@@ -2153,7 +2253,7 @@ LogicalResult ExtractIterSpaceOp::verify() {
}
if (pIter) {
- IterSpaceType spaceTp = getResultSpace().getType();
+ IterSpaceType spaceTp = getExtractedSpace().getType();
if (pIter.getType().getEncoding() != spaceTp.getEncoding())
return emitOpError(
"mismatch in parent iterator encoding and iteration space encoding.");
@@ -2166,6 +2266,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+/// Parses a `sparse_tensor.iterate` operation: the loop header (handled by
+/// parseSparseSpaceLoop), the body region, and an optional trailing attribute
+/// dictionary.
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::Argument> iters, iterArgs;
+ if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+ return failure();
+ if (iters.size() != 1)
+ return parser.emitError(parser.getNameLoc(),
+ "expected only one iterator/iteration space");
+
+ // Region arguments are ordered as: iterator, used coordinates, then the
+ // loop-carried values.
+ iters.append(iterArgs);
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, iters))
+ return failure();
+
+ // The terminator may be omitted in the textual form; materialize it.
+ IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ return success();
+}
+
+/// Emits `<prefix>(%inner = %outer, %inner2 = %outer2, <...>)`, pairing each
+/// region argument in `blocksArgs` ("inner") with the SSA value in
+/// `initializers` ("outer") at the same position. Nothing is printed when
+/// there are no initializers.
+static void printInitializationList(OpAsmPrinter &p,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ p << prefix << '(';
+ for (unsigned idx = 0, e = initializers.size(); idx < e; ++idx) {
+ if (idx != 0)
+ p << ", ";
+ p << blocksArgs[idx] << " = " << initializers[idx];
+ }
+ p << ")";
+}
+
+/// Emits ` at(<crd>, ...)` with one entry per level of the iteration space:
+/// the next coordinate argument from `blocksArgs` when the level is in
+/// `crdUsedLvls`, `_` otherwise. Nothing is printed when no coordinate is
+/// used.
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+ Block::BlockArgListType blocksArgs,
+ LevelSet crdUsedLvls) {
+ if (crdUsedLvls.empty())
+ return;
+
+ p << " at(";
+ unsigned nextCrd = 0;
+ for (unsigned lvl = 0; lvl < spaceDim; ++lvl) {
+ if (lvl != 0)
+ p << ", ";
+ if (crdUsedLvls[lvl])
+ p << blocksArgs[nextCrd++];
+ else
+ p << "_";
+ }
+ // Every coordinate argument must have been consumed by a used level.
+ assert(nextCrd == blocksArgs.size());
+ p << ")";
+}
+
+/// Prints the operation in its custom form:
+///   %it in %space [at(...)] [iter_args(...)] : space-type [-> (types)] {...}
+void IterateOp::print(OpAsmPrinter &p) {
+ auto space = getIterSpace();
+ auto inits = getInitArgs();
+ const bool hasInits = !inits.empty();
+
+ p << " " << getIterator() << " in " << space;
+ printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+ printInitializationList(p, getRegionIterArgs(), inits, " iter_args");
+
+ p << " : " << space.getType() << " ";
+ if (hasInits)
+ p << "-> (" << inits.getTypes() << ") ";
+
+ // Block arguments are already spelled out in the header; the terminator is
+ // only printed when there are loop-carried values.
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/hasInits);
+}
+
+/// Checks that the operation defines exactly one result per initial
+/// loop-carried operand.
+LogicalResult IterateOp::verify() {
+ if (getInitArgs().size() == getNumResults())
+ return success();
+ return emitOpError(
+ "mismatch in number of loop-carried values and defined values");
+}
+
+/// Checks the body region against the operation: the iterator argument must
+/// have the iterator type of the iteration space, and the initial operands,
+/// loop-carried region arguments, yielded values and results must agree in
+/// number and, position by position, in type.
+LogicalResult IterateOp::verifyRegions() {
+ if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+ return emitOpError("mismatch in iterator and iteration space type");
+ if (getNumRegionIterArgs() != getNumResults())
+ return emitOpError(
+ "mismatch in number of basic block args and defined values");
+
+ auto inits = getInitArgs();
+ auto carried = getRegionIterArgs();
+ auto yielded = getYieldedValues();
+ auto defined = getResults();
+ const size_t numDefined = defined.size();
+ if (inits.size() != numDefined || carried.size() != numDefined ||
+ yielded.size() != numDefined)
+ return emitOpError() << "number mismatch between iter args and results.";
+
+ // All four ranges have the same length here, so walk them by position and
+ // compare each against the type of the defined result.
+ for (size_t i = 0; i < numDefined; ++i) {
+ auto definedTp = defined[i].getType();
+ if (inits[i].getType() != definedTp)
+ return emitOpError() << "types mismatch between " << i
+ << "th iter operand and defined value";
+ if (carried[i].getType() != definedTp)
+ return emitOpError() << "types mismatch between " << i
+ << "th iter region arg and defined value";
+ if (yielded[i].getType() != definedTp)
+ return emitOpError() << "types mismatch between " << i
+ << "th yield value and defined value";
+ }
+
+ return success();
+}
+
+/// LoopLikeOpInterface and RegionBranchOpInterface methods implemented by IterateOp.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } // the single body region
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+ return getInitArgsMutable(); // initial values of the loop-carried variables
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+ return getRegion().getArguments().take_back(getNumRegionIterArgs()); // loop-carried args are the trailing block arguments
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+ return cast<sparse_tensor::YieldOp>(
+ getRegion().getBlocks().front().getTerminator())
+ .getResultsMutable(); // operands of the terminating sparse_tensor.yield
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } // every result is a loop result
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ return getInitArgs(); // `point` is not consulted; the init args are always forwarded
+}
+
+/// Appends the possible control-flow successors to `regions`: the loop body
+/// (receiving the loop-carried region arguments) and the operation itself
+/// (receiving the results). The same two successors are reported regardless
+/// of `point`.
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ // Both the operation itself and the region may be branching into the body or
+ // back into the operation itself.
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+ // It is possible for loop not to enter the body.
+ regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3fa696e1600a9..eb0dc01be25b9 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1025,6 +1025,7 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+ -> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
return
}
@@ -1040,6 +1041,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1054,7 +1056,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1077,6 +1079,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1092,5 +1095,63 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 2>
return
}
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+ // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
+ %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
+ sparse_tensor.yield %si : index
+ }
+ return %r1 : index
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// expected-note@+1 {{prior use here}}
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+ // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
+ sparse_tensor.yield %outer : f32
+ }
+ return %r1 : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+ // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
+ %y = arith.constant 1.0 : f32
+ sparse_tensor.yield %y : f32
+ }
+ return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d34071279e512..bce0b41a99828 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -758,8 +758,37 @@ func.func @sparse_has_runtime() -> i1 {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
-> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
// Extracting the iteration space for the first level needs no parent iterator.
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// Extracting the iteration space for the second level needs a parent iterator.
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
}
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_iterate(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME: %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: index) -> index {
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK: sparse_tensor.yield %[[VAL_7]] : index
+// CHECK: }
+// CHECK: return %[[VAL_4]] : index
+// CHECK: }
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+ sparse_tensor.yield %outer : index
+ }
+ return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
new file mode 100644
index 0000000000000..f70fab3b7251d
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : dense,
+ j : compressed
+ )
+}>
+
+// Make sure that pure instructions are hoisted outside the loop.
+//
+// CHECK: sparse_tensor.values
+// CHECK: sparse_tensor.positions
+// CHECK: sparse_tensor.coordinate
+// CHECK: sparse_tensor.iterate
+func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+ -> !sparse_tensor.iter_space<#CSR, lvls = 0>
+ sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
+ %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
+ %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+ %2 = sparse_tensor.coordinates %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+ "test.op"(%0, %1, %2) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
+ }
+
+ return
+}
More information about the Mlir-commits
mailing list