[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.coiterate` operation. (PR #101100)
Aart Bik
llvmlistbot at llvm.org
Tue Jul 30 11:10:17 PDT 2024
================
@@ -1644,6 +1644,123 @@ def IterateOp : SparseTensor_Op<"iterate",
let hasCustomAssemblyFormat = 1;
}
+def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
+ [AttrSizedOperandSegments,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
+ RecursiveMemoryEffects]> {
+ let summary = "CoIterates over a set of sparse iteration spaces";
+ let description = [{
+ The `sparse_tensor.coiterate` operation represents a loop (nest) over
+ a set of iteration spaces.
+ The operation can have multiple regions, with each of them defining a
+ case to compute a result at the current iterations. The case condition
+ is defined solely based on the pattern of specified iterators.
+ For example:
+ ```mlir
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+ !sparse_tensor.iter_space<#COO, lvls = 0>)
+ -> index
+ case %it1, _ {
+ // %coord is specified in space %sp1 but *NOT* specified in space %sp2.
+ }
+ case %it1, %it2 {
+ // %coord is specified in *BOTH* spaces %sp1 and %sp2.
+ }
+ ```
+
+ `sparse_tensor.coiterate` can also operate on loop-carried variables.
+ It returns the final values after loop termination.
+ The initial values of the variables are passed as additional SSA operands
+ to the iterator SSA value and used coordinate SSA values.
+ Each operation region has variadic arguments for the specified (used)
+ coordinates, followed by one argument for each loop-carried variable,
+ representing the value of the variable at the current iteration, followed
+ by a list of arguments for iterators.
+ The body region must contain exactly one block that terminates with
+ `sparse_tensor.yield`.
+
+ The results of a `sparse_tensor.coiterate` hold the final values after
+ the last iteration. If the `sparse_tensor.coiterate` defines any values,
+ a yield must be explicitly present in every region defined in the operation.
+ The number and types of the `sparse_tensor.coiterate` results must match
+ the initial values in the iter_args binding and the yield operands.
+
+
+ A `sparse_tensor.coiterate` example that does elementwise addition between two
+ sparse vectors.
+
+
+ ```mlir
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+ !sparse_tensor.iter_space<#CSR, lvls = 0>)
+ -> tensor<?xindex, #CSR>
+ case %it1, _ {
+ // v = v1 + 0 = v1
+ %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+ %yield = sparse_tensor.insert %v1 into %arg[%coord]
+ sparse_tensor.yield %yield
+ }
+ case _, %it2 {
+ // v = v2 + 0 = v2
+ %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+ %yield = sparse_tensor.insert %v2 into %arg[%coord]
+ sparse_tensor.yield %yield
+ }
+ case %it1, %it2 {
+ // v = v1 + v2
+ %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+ %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+ %v = arith.addi %v1, %v2 : index
+ %yield = sparse_tensor.insert %v into %arg[%coord]
+ sparse_tensor.yield %yield
+ }
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
+ Variadic<AnyType>:$initArgs,
+ I64BitSetAttr:$crdUsedLvls,
+ I64BitSetArrayAttr:$cases);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
+ getIterSpaces().front().getType())
+ .getSpaceDim();
+ }
+ I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
+ return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
+ .getValue().getZExtValue());
+ }
// The block arguments start with the referenced coordinates, followed by
// the user-provided iteration arguments, and end with the iterators.
+ Block::BlockArgListType getCrds(unsigned regionIdx) {
+ return getRegion(regionIdx).getArguments()
+ .take_front(getCrdUsedLvls().count());
+ }
+ unsigned getNumRegionIterArgs(unsigned regionIdx) {
+ return getInitArgs().size();
+ }
+ Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
+ return getRegion(regionIdx).getArguments()
+ .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+ }
+ Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
+ return getRegion(regionIdx).getArguments()
+ .take_back(getRegionDefinedSpace(regionIdx).count());
+ }
+ }];
+
+ // TODO:
+ // let hasVerifier = 1;
----------------
aartbik wrote:
don't forget to add invalid.mlir tests as well when you add the verifiers
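One invariant a future verifier might check follows from the accessors in the quoted hunk: each case region's block arguments are the used coordinates, then the loop-carried arguments, then the iterators defined by that case. The standalone C++ sketch below models that layout only; it does not use the MLIR API. `CaseArgs`, `splitCaseArgs`, and the string stand-ins for SSA values are hypothetical names introduced for illustration, and the size check is an assumption about what the verifier could enforce, not something the patch states.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model of one case region's block arguments, split into the
// three groups that getCrds / getRegionIterArgs / getRegionIterators return.
struct CaseArgs {
  std::vector<std::string> crds;      // used coordinates (front)
  std::vector<std::string> iterArgs;  // loop-carried values (middle)
  std::vector<std::string> iterators; // iterators defined by the case (back)
};

// Number of set bits, mirroring what a bit-set count() would report.
static std::size_t countBits(std::uint64_t bits) {
  return std::bitset<64>(bits).count();
}

// Splits a flat argument list using the same arithmetic as the accessors:
// take_front(#crds), slice(#crds, #initArgs), take_back(#iterators).
CaseArgs splitCaseArgs(const std::vector<std::string> &args,
                       std::uint64_t crdUsedLvls, std::size_t numInitArgs,
                       std::uint64_t caseBits) {
  const std::size_t numCrds = countBits(crdUsedLvls);
  const std::size_t numIters = countBits(caseBits);
  // Assumed invariant: the three groups cover the argument list exactly.
  assert(args.size() == numCrds + numInitArgs + numIters);

  const auto crdEnd = args.begin() + static_cast<std::ptrdiff_t>(numCrds);
  const auto argEnd = crdEnd + static_cast<std::ptrdiff_t>(numInitArgs);

  CaseArgs out;
  out.crds.assign(args.begin(), crdEnd);
  out.iterArgs.assign(crdEnd, argEnd);
  out.iterators.assign(argEnd, args.end());
  return out;
}
```

For the `case %it1, _` region of the quoted example, the flat list would be `{"%coord", "%arg", "%it1"}` with `crdUsedLvls = 0x1`, one init arg, and `caseBits = 0x1`; a region whose argument count does not add up is the kind of input an invalid.mlir test would presumably exercise.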
https://github.com/llvm/llvm-project/pull/101100
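For readers unfamiliar with co-iteration, the three-case elementwise addition in the quoted description corresponds roughly to a two-pointer merge over sorted coordinates. The sketch below is plain C++ under that assumption, not the lowering the PR implements; `SparseVec` and `addSparse` are hypothetical names, and each vector is assumed to hold (coordinate, value) pairs sorted by coordinate with no duplicates.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical sparse vector: (coordinate, value) pairs sorted by coordinate.
using SparseVec = std::vector<std::pair<std::size_t, std::int64_t>>;

// Walks both vectors in coordinate order and dispatches on which of them
// has an entry at the current coordinate, like the three `case` regions.
SparseVec addSparse(const SparseVec &a, const SparseVec &b) {
  SparseVec out;
  std::size_t i = 0, j = 0;
  while (i < a.size() || j < b.size()) {
    // The current coordinate is the smaller of the two heads; a vector is
    // "specified" at it if its head sits exactly on that coordinate.
    const bool inA =
        i < a.size() && (j >= b.size() || a[i].first <= b[j].first);
    const bool inB =
        j < b.size() && (i >= a.size() || b[j].first <= a[i].first);
    if (inA && inB) {
      // case %it1, %it2: v = v1 + v2
      out.emplace_back(a[i].first, a[i].second + b[j].second);
      ++i;
      ++j;
    } else if (inA) {
      // case %it1, _: v = v1
      out.push_back(a[i]);
      ++i;
    } else {
      // case _, %it2: v = v2
      out.push_back(b[j]);
      ++j;
    }
  }
  return out;
}
```

Adding `{(0, 1), (2, 5)}` and `{(2, 7), (3, 4)}` visits coordinates 0, 2, and 3 and takes the first-only, both, and second-only branches in turn.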