[clang] [mlir][sparse] Print new syntax (PR #68130)

Peiming Liu via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 3 16:24:45 PDT 2023


================
@@ -586,30 +586,56 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
-  // Print the struct-like storage in dictionary fashion.
-  printer << "<{ lvlTypes = [ ";
-  llvm::interleaveComma(getLvlTypes(), printer, [&](DimLevelType dlt) {
-    printer << "\"" << toMLIRString(dlt) << "\"";
-  });
-  printer << " ]";
+  auto map = static_cast<AffineMap>(getDimToLvl());
+  auto lvlTypes = getLvlTypes();
+  // Empty affine map indicates identity map
+  if (!map) {
+    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
+  }
+  // Modified version of AsmPrinter::Impl::printAffineMap.
+  printer << "<{ map = ";
+  // Symbolic identifiers.
+  if (map.getNumSymbols() != 0) {
+    printer << '[';
+    for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
+      printer << 's' << i << ", ";
+    if (map.getNumSymbols() >= 1)
+      printer << 's' << map.getNumSymbols() - 1;
+    printer << ']';
+  }
+  // Dimension identifiers.
+  printer << '(';
+  auto dimSlices = getDimSlices();
+  if (!dimSlices.empty()) {
+    for (unsigned i = 0; i < map.getNumDims() - 1; ++i)
+      printer << 'd' << i << " : " << dimSlices[i] << ", ";
+    if (map.getNumDims() >= 1)
+      printer << 'd' << map.getNumDims() - 1 << " : "
+              << dimSlices[map.getNumDims() - 1];
+  } else {
+    for (unsigned i = 0; i < map.getNumDims() - 1; ++i)
+      printer << 'd' << i << ", ";
+    if (map.getNumDims() >= 1)
+      printer << 'd' << map.getNumDims() - 1;
+  }
+  printer << ')';
+  // Level format and properties.
+  printer << " -> (";
+  for (unsigned i = 0; i < map.getNumResults() - 1; ++i) {
+    map.getResult(i).print(printer.getStream());
+    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
+  }
+  if (map.getNumResults() >= 1) {
+    auto lastIndex = map.getNumResults() - 1;
+    map.getResult(lastIndex).print(printer.getStream());
+    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
+  }
+  printer << ')';
----------------
PeimingLiu wrote:

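I suggest breaking this into smaller helper functions — for example, one each for printing the symbol identifiers, the dimension identifiers (with their optional slices), and the level expressions with their level types.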

https://github.com/llvm/llvm-project/pull/68130


More information about the cfe-commits mailing list