[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.coiterate` operation. (PR #101100)

Aart Bik llvmlistbot at llvm.org
Tue Jul 30 11:10:17 PDT 2024


================
@@ -1644,6 +1644,123 @@ def IterateOp : SparseTensor_Op<"iterate",
   let hasCustomAssemblyFormat = 1;
 }
 
+def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
+    [AttrSizedOperandSegments,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
+     RecursiveMemoryEffects]> {
+  let summary = "Co-iterates over a set of sparse iteration spaces";
+  let description = [{
+      The `sparse_tensor.coiterate` operation represents a loop (nest) over
+      a set of iteration spaces. The operation can have multiple regions,
+      each defining a case that computes a result for the current iteration.
+      A case's condition is defined solely by the pattern of specified
+      iterators.
+      For example:
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#COO, lvls = 0>)
+           -> index
+      case %it1, _ {
+        // %coord is specified in space %sp1 but *NOT* specified in space %sp2.
+      }
+      case %it1, %it2 {
+        // %coord is specified in *BOTH* spaces %sp1 and %sp2.
+      }
+      ```
+
+      `sparse_tensor.coiterate` can also operate on loop-carried variables.
+      It returns the final values after loop termination.
+      The initial values of the variables are passed as additional SSA operands
+      through the `iter_args` binding.
+      Each region of the operation has block arguments for the specified (used)
+      coordinates, followed by one argument for each loop-carried variable,
+      representing the value of the variable at the current iteration, and ends
+      with a list of arguments for the iterators defined by that case.
+      The body region must contain exactly one block that terminates with
+      `sparse_tensor.yield`.
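+
+      For instance, given `at(%coord)` and `iter_args(%arg = %init)`, the block
+      of `case %it1, %it2` has the arguments `(%coord, %arg, %it1, %it2)`: the
+      used coordinate first, then the loop-carried variable, then the iterators.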
+
+      The results of a `sparse_tensor.coiterate` hold the final values after
+      the last iteration. If the `sparse_tensor.coiterate` defines any values,
+      a yield must be explicitly present in every region of the operation.
+      The number and types of the `sparse_tensor.coiterate` results must match
+      the initial values in the `iter_args` binding and the yield operands.
+
+      A `sparse_tensor.coiterate` example that performs elementwise addition of
+      two sparse vectors:
+
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#CSR, lvls = 0>)
+           -> tensor<?xindex, #CSR>
+      case %it1, _ {
+         // v = v1 + 0 = v1
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %yield = sparse_tensor.insert %v1 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case _, %it2 {
+         // v = v2 + 0 = v2
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %yield = sparse_tensor.insert %v2 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case %it1, %it2 {
+         // v = v1 + v2
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %v = arith.addi %v1, %v2 : index
+         %yield = sparse_tensor.insert %v into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      ```
+  }];
+
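+  // `crdUsedLvls` is a bit set over the coordinate levels used by the loop;
+  // `cases` stores one bit set per case region (presumably with bit i set
+  // when the i-th iteration space is specified in that region's pattern).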
+  let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
+                       Variadic<AnyType>:$initArgs,
+                       I64BitSetAttr:$crdUsedLvls,
+                       I64BitSetArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
+                 getIterSpaces().front().getType())
+          .getSpaceDim();
+    }
+    I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
+      return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
+                           .getValue().getZExtValue());
+    }
+    // The block arguments start with the referenced coordinates, followed by
+    // the user-provided iteration arguments, and end with the iterators.
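+    // E.g., region r's arguments slice as: the first crdUsedLvls.count() are
+    // coordinates, the next getInitArgs().size() are iter_args, and the
+    // trailing getRegionDefinedSpace(r).count() are the case's iterators.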
+    Block::BlockArgListType getCrds(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_front(getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs(unsigned regionIdx) {
+      return getInitArgs().size();
+    }
+    Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+    }
+    Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_back(getRegionDefinedSpace(regionIdx).count());
+    }
+  }];
+
+  // TODO:
+  // let hasVerifier = 1;
----------------
aartbik wrote:

don't forget to add invalid.mlir tests as well when you add the verifiers

https://github.com/llvm/llvm-project/pull/101100

