[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.crd_translate operation (PR #69630)
Yinying Li
llvmlistbot at llvm.org
Thu Oct 19 12:27:41 PDT 2023
================
@@ -1160,6 +1160,60 @@ bool ConvertOp::needsExtraSort() {
return true;
}
+/// Checks that the number of input and output coordinates agrees with the
+/// dimension/level ranks of the oracle encoding, as selected by the
+/// translation direction.
+LogicalResult CrdTranslateOp::verify() {
+  const bool isDim2Lvl = getDirection() == CrdTransDirectionKind::dim2lvl;
+  const uint64_t dimRank = getOracle().getDimRank();
+  const uint64_t lvlRank = getOracle().getLvlRank();
+
+  // dim2lvl takes dimension coordinates in and produces level coordinates;
+  // any other direction takes level coordinates in and produces dimension
+  // coordinates.
+  const uint64_t expectedIn = isDim2Lvl ? dimRank : lvlRank;
+  const uint64_t expectedOut = isDim2Lvl ? lvlRank : dimRank;
+
+  const bool ranksMatch = getInCrds().size() == expectedIn &&
+                          getOutCrds().size() == expectedOut;
+  if (!ranksMatch)
+    return emitError("Coordinate rank mismatch with encoding");
+
+  return success();
+}
+
+LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &results) {
+ if (getOracle().isPermutation()) {
+ AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
+ ? getOracle().getDimToLvl()
+ : getOracle().getLvlToDim();
+ for (AffineExpr exp : perm.getResults())
+ results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
+ return success();
+ }
+
+ // Fuse dim2lvl/lvl2dim pairs.
+ auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
+ bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
+ return v.getDefiningOp() == def;
+ });
+ if (!sameDef)
+ return failure();
+
+ bool oppositeDir = def.getDirection() != getDirection();
+ bool sameOracle = def.getOracle().getDimToLvl() == getOracle().getDimToLvl();
+ bool sameCount = def.getNumResults() == getInCrds().size();
+ if (!oppositeDir || !sameOracle || !sameCount)
+ return failure();
+
+ // The definition produce the coordinate in the same order as the input
----------------
yinying-lisa-li wrote:
nit: produces and coordinates
https://github.com/llvm/llvm-project/pull/69630
More information about the Mlir-commits
mailing list