[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.crd_translate operation (PR #69630)

Yinying Li llvmlistbot at llvm.org
Thu Oct 19 12:27:41 PDT 2023


================
@@ -1160,6 +1160,60 @@ bool ConvertOp::needsExtraSort() {
   return true;
 }
 
+LogicalResult CrdTranslateOp::verify() { // Checks in/out coordinate counts against the oracle's ranks.
+  uint64_t inRank = getOracle().getLvlRank(); // Starting assumption: inputs are counted against the level rank...
+  uint64_t outRank = getOracle().getDimRank(); // ...and outputs against the dimension rank.
+
+  if (getDirection() == CrdTransDirectionKind::dim2lvl)
+    std::swap(inRank, outRank); // dim2lvl flips the roles: dim-rank inputs, lvl-rank outputs.
+
+  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
+    return emitError("Coordinate rank mismatch with encoding"); // A mismatch on either side is rejected.
+
+  return success();
+}
+
+LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
+                                   SmallVectorImpl<OpFoldResult> &results) {
+  if (getOracle().isPermutation()) {
+    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
+                         ? getOracle().getDimToLvl()
+                         : getOracle().getLvlToDim();
+    for (AffineExpr exp : perm.getResults())
+      results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
+    return success();
+  }
+
+  // Fuse dim2lvl/lvl2dim pairs.
+  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
+  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
+                   return v.getDefiningOp() == def;
+                 });
+  if (!sameDef)
+    return failure();
+
+  bool oppositeDir = def.getDirection() != getDirection();
+  bool sameOracle = def.getOracle().getDimToLvl() == getOracle().getDimToLvl();
+  bool sameCount = def.getNumResults() == getInCrds().size();
+  if (!oppositeDir || !sameOracle || !sameCount)
+    return failure();
+
+  // The definition produce the coordinate in the same order as the input
----------------
yinying-lisa-li wrote:

nit: "produce" should be "produces" and "coordinate" should be "coordinates".

https://github.com/llvm/llvm-project/pull/69630


More information about the Mlir-commits mailing list