[Mlir-commits] [mlir] f539e00 - [mlir] add option to print SSA IDs using `NameLoc`s as prefixes (#119996)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 17 07:59:16 PST 2024


Author: Maksim Levental
Date: 2024-12-17T07:59:11-08:00
New Revision: f539e00c702b4e5732d76e093c2d909fd8702683

URL: https://github.com/llvm/llvm-project/commit/f539e00c702b4e5732d76e093c2d909fd8702683
DIFF: https://github.com/llvm/llvm-project/commit/f539e00c702b4e5732d76e093c2d909fd8702683.diff

LOG: [mlir] add option to print SSA IDs using `NameLoc`s as prefixes (#119996)

This PR adds an `AsmPrinter` option `-mlir-use-nameloc-as-prefix` which
uses trailing `NameLoc`s, if the source IR provides them, as prefixes
when printing SSA IDs.

Added: 
    mlir/test/IR/print-use-nameloc-as-prefix.mlir

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/AsmPrinter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 7b19a7bf6bf471..9fb0aaab064619 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -1398,6 +1398,27 @@ $ tree /tmp/pipeline_output
 │   │   ├── 1_1_pass4.mlir
 ```
 
+*   `mlir-use-nameloc-as-prefix`
+    * If your source IR has named locations (`loc("named_location")`), passing this flag makes the printer use
+      those names (`named_location`) as prefixes for the corresponding SSA identifiers:
+    
+      ```mlir
+      %1 = memref.load %0[] : memref<i32> loc("alice")  
+      %2 = memref.load %0[] : memref<i32> loc("bob")
+      %3 = memref.load %0[] : memref<i32> loc("bob")
+      ```
+      
+      is printed as:
+    
+      ```mlir
+      %alice = memref.load %0[] : memref<i32>
+      %bob = memref.load %0[] : memref<i32>
+      %bob_0 = memref.load %0[] : memref<i32>
+      ```
+      
+      Because the names are carried by locations, they also appear on operations that passes create, as long as the pass gives those operations the original (named) location.
+      
+
 ## Crash and Failure Reproduction
 
 The [pass manager](#pass-manager) in MLIR contains a builtin mechanism to

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
   /// Return if printer should use unique SSA IDs.
   bool shouldPrintUniqueSSAIDs() const;
 
+  /// Return if the printer should use the names of NameLocs found in value
+  /// locations as prefixes when printing SSA IDs.
+  bool shouldUseNameLocAsPrefix() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
 
   /// Print unique SSA IDs for values, block arguments and naming conflicts
   bool printUniqueSSAIDsFlag : 1;
+
+  /// Print SSA IDs using the names of NameLocs as prefixes.
+  bool useNameLocAsPrefix : 1;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..99b7abe7db1f94 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
 
 /// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
   return parseCommaSeparatedList(
       [&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
       "mlir-print-unique-ssa-ids", llvm::cl::init(false),
       llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
                      "and naming conflicts across all regions")};
+
+  llvm::cl::opt<bool> useNameLocAsPrefix{
+      "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+      llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
 };
 } // namespace
 
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
       printGenericOpFormFlag(false), skipRegionsFlag(false),
       assumeVerifiedFlag(false), printLocalScope(false),
-      printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+      printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+      useNameLocAsPrefix(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
   skipRegionsFlag = clOptions->skipRegionsOpt;
   printValueUsersFlag = clOptions->printValueUsers;
   printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+  useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
 }
 
 /// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
   return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
 }
 
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs.
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+  return useNameLocAsPrefix; // Not implied by the generic op form.
+}
+
 //===----------------------------------------------------------------------===//
 // NewLineCounter
 //===----------------------------------------------------------------------===//
@@ -1506,11 +1518,22 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
   }
 }
 
+/// Try to get the value name from the value's location, falling back to `name`.
+/// NOTE(review): findInstanceOf searches nested locations too, so this may pick
+/// a NameLoc that is not the trailing one; confirm that is intended.
+static StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
+  if (auto nameLoc = value.getLoc()->findInstanceOf<NameLoc>())
+    return nameLoc.getName();
+  return name;
+}
+
 void SSANameState::numberValuesInRegion(Region &region) {
   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
     assert(!valueIDs.count(arg) && "arg numbered multiple times");
     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
            "arg not defined in current region");
+    if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
+      name = maybeGetValueNameFromLoc(arg, name);
     setValueName(arg, name);
   };
 
@@ -1553,7 +1576,10 @@ void SSANameState::numberValuesInBlock(Block &block) {
       specialNameBuffer.resize(strlen("arg"));
       specialName << nextArgumentID++;
     }
-    setValueName(arg, specialName.str());
+    StringRef specialNameStr = specialName.str();
+    if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
+      specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr);
+    setValueName(arg, specialNameStr);
   }
 
   // Number the operations in this block.
@@ -1567,6 +1593,8 @@ void SSANameState::numberValuesInOp(Operation &op) {
   auto setResultNameFn = [&](Value result, StringRef name) {
     assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+    if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
+      name = maybeGetValueNameFromLoc(result, name);
     setValueName(result, name);
 
     // Record the result number for groups not anchored at 0.
@@ -1607,6 +1635,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
   Value resultBegin = op.getResult(0);
 
+  if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) {
+    if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) {
+      setValueName(resultBegin, nameLoc.getName());
+    }
+  }
+
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;

diff  --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..ddee8aed5586cf
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() { // A result with a NameLoc is printed using that name.
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice = memref.load
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() { // Repeated NameLoc names get a numeric suffix.
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice = memref.load
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  // CHECK: %alice_0 = memref.load
+  %2 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args1
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) { // Argument NameLocs name the printed argument.
+  // CHECK: %alice = memref.load %foo
+  %1 = memref.load %arg0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) { // NameLocs on multi-result ops and on region block arguments.
+  // CHECK: %foo:2 = call @make_two_results
+  %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+  // CHECK: %bar:2 = call @make_two_results
+  %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+  // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+  %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+    %6 = arith.cmpi slt, %arg1, %arg2 : index
+    scf.condition(%6) %arg1, %arg2 : index, index
+  } do {
+  // CHECK: ^bb0(%alice: index, %bob: index)
+  ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+    %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+    // CHECK: scf.yield %harriet#1, %harriet#1
+    scf.yield %c2, %c2 : index, index
+  } loc("kevin")
+  return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+  iterator_types = ["parallel"],
+  indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) { // NameLocs override OpAsmOpInterface names; values without one keep them.
+  // CHECK: %c0 = arith.constant
+  %0 = arith.constant 0 : index
+  // CHECK: %foo = arith.constant
+  %1 = arith.constant 1 : index loc("foo")
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      linalg.yield %a, %a : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+    ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+      // CHECK: linalg.yield %alice, %steve
+      linalg.yield %b, %c : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) { // NameLocs survive loop unrolling on the cloned loads.
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg2 = %c0 to %c4 step %c1 {
+    // CHECK-PASS-PRESERVE: %foo = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo
+    // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo_1
+    %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+    memref.store %0, %arg1[%arg2] : memref<4xf32>
+  }
+  return
+}


        


More information about the Mlir-commits mailing list