[Mlir-commits] [mlir] f5b974b - [mlir][sparse] Adding `{Var, DimLvlExpr, DimSpec, LvlSpec, DimLvlMap}::str` methods

Aart Bik llvmlistbot at llvm.org
Tue Aug 22 20:31:30 PDT 2023


Author: wren romano
Date: 2023-08-22T20:31:15-07:00
New Revision: f5b974b7835a27de7c5d6142e935f2f95da997a3

URL: https://github.com/llvm/llvm-project/commit/f5b974b7835a27de7c5d6142e935f2f95da997a3
DIFF: https://github.com/llvm/llvm-project/commit/f5b974b7835a27de7c5d6142e935f2f95da997a3.diff

LOG: [mlir][sparse] Adding `{Var,DimLvlExpr,DimSpec,LvlSpec,DimLvlMap}::str` methods

These methods are needed for use with `Diagnostic::operator<<` etc.

The definitions follow the pattern of `Diagnostic::str` by simply wrapping the underlying `print(raw_ostream)` method.  Although there is some overhead for constructing the `std::string`, this seems like the most efficient option overall, since this overhead only occurs on the error path (under the current intended usage).  An alternative approach would be to have one method construct a `Twine` directly, and then have the print method pass the twine to the stream; however, that would introduce the overhead of twine construction on the common/happy path of simply printing things out.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D157643

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 792626b45283ea..e1eaa8a4d3f9c5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -98,6 +98,12 @@ void DimLvlExpr::dump() const {
   print(llvm::errs());
   llvm::errs() << "\n";
 }
+std::string DimLvlExpr::str() const {
+  std::string result; // Collects exactly what `print(os)` writes.
+  llvm::raw_string_ostream os(result);
+  print(os);
+  return result; // `os` is unbuffered, so `result` is complete; no copy via `os.str()`.
+}
 void DimLvlExpr::print(AsmPrinter &printer) const {
   print(printer.getStream());
 }
@@ -220,6 +226,12 @@ void DimSpec::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
 }
+std::string DimSpec::str(bool wantElision) const {
+  std::string result; // Collects exactly what `print(os, wantElision)` writes.
+  llvm::raw_string_ostream os(result);
+  print(os, wantElision);
+  return result; // `os` is unbuffered, so `result` is complete; no copy via `os.str()`.
+}
 void DimSpec::print(AsmPrinter &printer, bool wantElision) const {
   print(printer.getStream(), wantElision);
 }
@@ -260,6 +272,12 @@ void LvlSpec::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
 }
+std::string LvlSpec::str(bool wantElision) const {
+  std::string result; // Collects exactly what `print(os, wantElision)` writes.
+  llvm::raw_string_ostream os(result);
+  print(os, wantElision);
+  return result; // `os` is unbuffered, so `result` is complete; no copy via `os.str()`.
+}
 void LvlSpec::print(AsmPrinter &printer, bool wantElision) const {
   print(printer.getStream(), wantElision);
 }
@@ -345,6 +363,12 @@ void DimLvlMap::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
 }
+std::string DimLvlMap::str(bool wantElision) const {
+  std::string result; // Collects exactly what `print(os, wantElision)` writes.
+  llvm::raw_string_ostream os(result);
+  print(os, wantElision);
+  return result; // `os` is unbuffered, so `result` is complete; no copy via `os.str()`.
+}
 void DimLvlMap::print(AsmPrinter &printer, bool wantElision) const {
   print(printer.getStream(), wantElision);
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 040d7ea919a642..2d02eed2cf9972 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -122,6 +122,7 @@ class DimLvlExpr {
   /// with respect to the given ranks.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
 
+  std::string str() const;
   void print(llvm::raw_ostream &os) const;
   void print(AsmPrinter &printer) const;
   void dump() const;
@@ -251,6 +252,7 @@ class DimSpec final {
   bool isFunctionOf(VarSet const &vars) const;
   void getFreeVars(VarSet &vars) const;
 
+  std::string str(bool wantElision = true) const;
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
@@ -306,6 +308,7 @@ class LvlSpec final {
   bool isFunctionOf(VarSet const &vars) const;
   void getFreeVars(VarSet &vars) const;
 
+  std::string str(bool wantElision = true) const;
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
@@ -339,6 +342,7 @@ class DimLvlMap final {
   AffineMap getDimToLvlMap(MLIRContext *context) const;
   AffineMap getLvlToDimMap(MLIRContext *context) const;
 
+  std::string str(bool wantElision = true) const;
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 15dae63649ede3..3b00e17657f1f9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -25,6 +25,13 @@ static constexpr const VarKind everyVarKind[] = {
 // `Var` implementation.
 //===----------------------------------------------------------------------===//
 
+std::string Var::str() const {
+  std::string result; // Collects exactly what `print(os)` writes.
+  llvm::raw_string_ostream os(result);
+  print(os);
+  return result; // `os` is unbuffered, so `result` is complete; no copy via `os.str()`.
+}
+
 void Var::print(AsmPrinter &printer) const { print(printer.getStream()); }
 
 void Var::print(llvm::raw_ostream &os) const {

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index 18c68dd5e1118e..a488b3ea2d56ba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -197,6 +197,7 @@ class Var {
   template <typename U>
   constexpr std::optional<U> dyn_cast() const;
 
+  std::string str() const;
   void print(llvm::raw_ostream &os) const;
   void print(AsmPrinter &printer) const;
   void dump() const;


        


More information about the Mlir-commits mailing list