[Mlir-commits] [flang] [mlir] [mlir] add option to print SSA IDs using `NameLoc`s as prefixes (PR #119996)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 15 15:13:43 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This PR adds an `AsmPrinter` option `-mlir-use-nameloc-as-prefix` which uses trailing `NameLoc`s, if the source IR provides them, as prefixes when printing SSA IDs. **Note**: this PR only changes `AsmPrinter`.

For example:

```mlir
%1 = memref.load %0[] : memref<i32> loc("alice")
```

prints

```mlir
%alice = memref.load %0[] : memref<i32>
```

Currently, single/multiple results and block arguments are handled:

```mlir
%1:2 = call @make_two_results() : () -> (index, index) loc("bar")
%2:2 = scf.while (%arg1 = %1#0, %arg2 = %1#0) : (index, index) -> (index, index) {
  %3 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%3) %arg1, %arg2 : index, index
} do {
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
  %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
  scf.yield %c2, %c2 : index, index
} loc("kevin")
```

becomes

```mlir
%bar:2 = call @make_two_results() : () -> (index, index)
%kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0) : (index, index) -> (index, index) {
  %0 = arith.cmpi slt, %arg1, %arg2 : index
  scf.condition(%0) %arg1, %arg2 : index, index
} do {
^bb0(%alice: index, %bob: index):
  %harriet:2 = func.call @make_two_results() : () -> (index, index)
  scf.yield %harriet#1, %harriet#1 : index, index
}
```

The changes here are also compatible with `OpAsmOpInterface`. Note, though, that if an op implements `getAsmBlockArgumentNames` but not `getAsmResultNames` (like [`linalg.generic`](https://github.com/makslevental/llvm-project/blob/288f05f63e5f3246657aca9561d75b2aa02cb6f5/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td#L56)), then a `loc("...")` attached to the op's results will be ignored, because [`setResultNameFn`](https://github.com/llvm/llvm-project/blob/d61562c967b3a6f0c34e00c90155d16580288d8a/mlir/lib/IR/AsmPrinter.cpp#L1601) is never called for them.

### Testing

Besides the added lit test, I "stress tested" this by turning it on by default and checking whether anything broke ([this commit](https://github.com/llvm/llvm-project/pull/119996/commits/bbca902568bfb047ef67a96dad631314669fd6f9)). I then fixed my bugs and stress tested again (this [commit](https://github.com/llvm/llvm-project/pull/119996/commits/ca084f95b2c1a424392049bc4a8555f4f23d83fe) and this [passing test](https://buildkite.com/llvm-project/github-pull-requests/builds/129062#0193c8fd-9284-4c5f-b4c0-ec337d75f7be); the Windows test failure there was a flake).

---
Full diff: https://github.com/llvm/llvm-project/pull/119996.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/OperationSupport.h (+7) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+57-6) 
- (added) mlir/test/IR/print-use-nameloc-as-prefix.mlir (+105) 


``````````diff
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
   /// Return if printer should use unique SSA IDs.
   bool shouldPrintUniqueSSAIDs() const;
 
+  /// Return if the printer should use NameLocs as prefixes when printing SSA
+  /// IDs
+  bool shouldUseNameLocAsPrefix() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
 
   /// Print unique SSA IDs for values, block arguments and naming conflicts
   bool printUniqueSSAIDsFlag : 1;
+
+  /// Print SSA IDs using NameLocs as prefixes
+  bool useNameLocAsPrefix : 1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..1a81414b358d06 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
 
 /// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
   return parseCommaSeparatedList(
       [&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
       "mlir-print-unique-ssa-ids", llvm::cl::init(false),
       llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
                      "and naming conflicts across all regions")};
+
+  llvm::cl::opt<bool> useNameLocAsPrefix{
+      "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+      llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
 };
 } // namespace
 
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
       printGenericOpFormFlag(false), skipRegionsFlag(false),
       assumeVerifiedFlag(false), printLocalScope(false),
-      printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+      printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+      useNameLocAsPrefix(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
   skipRegionsFlag = clOptions->skipRegionsOpt;
   printValueUsersFlag = clOptions->printValueUsers;
   printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+  useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
 }
 
 /// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
   return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
 }
 
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+  return useNameLocAsPrefix;
+}
+
 //===----------------------------------------------------------------------===//
 // NewLineCounter
 //===----------------------------------------------------------------------===//
@@ -1511,13 +1523,30 @@ void SSANameState::numberValuesInRegion(Region &region) {
     assert(!valueIDs.count(arg) && "arg numbered multiple times");
     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
            "arg not defined in current region");
-    setValueName(arg, name);
+    if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+      auto nameLoc = cast<NameLoc>(arg.getLoc());
+      setValueName(arg, nameLoc.getName());
+    } else {
+      setValueName(arg, name);
+    }
   };
 
+  bool alreadySetNames = false;
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (Operation *op = region.getParentOp()) {
-      if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+      if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) {
         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+        alreadySetNames = true;
+      }
+    }
+  }
+
+  if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames) {
+    for (BlockArgument arg : region.getArguments()) {
+      if (isa<NameLoc>(arg.getLoc())) {
+        auto nameLoc = cast<NameLoc>(arg.getLoc());
+        setBlockArgNameFn(arg, nameLoc.getName());
+      }
     }
   }
 
@@ -1553,7 +1582,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
       specialNameBuffer.resize(strlen("arg"));
       specialName << nextArgumentID++;
     }
-    setValueName(arg, specialName.str());
+    if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+      auto nameLoc = cast<NameLoc>(arg.getLoc());
+      setValueName(arg, nameLoc.getName());
+    } else {
+      setValueName(arg, specialName.str());
+    }
   }
 
   // Number the operations in this block.
@@ -1567,7 +1601,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
   auto setResultNameFn = [&](Value result, StringRef name) {
     assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
-    setValueName(result, name);
+    if (printerFlags.shouldUseNameLocAsPrefix() &&
+        isa<NameLoc>(result.getLoc())) {
+      auto nameLoc = cast<NameLoc>(result.getLoc());
+      setValueName(result, nameLoc.getName());
+    } else {
+      setValueName(result, name);
+    }
 
     // Record the result number for groups not anchored at 0.
     if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1589,14 +1629,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
     blockNames[block] = {-1, name};
   };
 
+  bool alreadySetNames = false;
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
       asmInterface.getAsmBlockNames(setBlockNameFn);
       asmInterface.getAsmResultNames(setResultNameFn);
+      alreadySetNames = true;
     }
   }
 
   unsigned numResults = op.getNumResults();
+  if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames &&
+      numResults > 0) {
+    Value resultBegin = op.getResult(0);
+    if (isa<NameLoc>(resultBegin.getLoc())) {
+      auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
+      setResultNameFn(resultBegin, nameLoc.getName());
+    }
+  }
+
   if (numResults == 0) {
     // If value users should be printed, operations with no result need an id.
     if (printerFlags.shouldPrintValueUsers()) {
diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..fb555d9708ee86
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() {
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() {
+  %0 = memref.alloc() : memref<i32>
+  // CHECK: %alice
+  %1 = memref.load %0[] : memref<i32> loc("alice")
+  // CHECK: %alice_0
+  %2 = memref.load %0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
+  // CHECK: %alice = memref.load %foo
+  %1 = memref.load %arg0[] : memref<i32> loc("alice")
+  return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) {
+  // CHECK: %foo:2
+  %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+  // CHECK: %bar:2
+  %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+  // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+  %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+    %6 = arith.cmpi slt, %arg1, %arg2 : index
+    scf.condition(%6) %arg1, %arg2 : index, index
+  } do {
+  // CHECK: ^bb0(%alice: index, %bob: index)
+  ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+    %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+    // CHECK: scf.yield %harriet#1, %harriet#1
+    scf.yield %c2, %c2 : index, index
+  } loc("kevin")
+  return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+  iterator_types = ["parallel"],
+  indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // CHECK: %c0
+  %0 = arith.constant 0 : index
+  // CHECK: %foo
+  %1 = arith.constant 1 : index loc("foo")
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      linalg.yield %a, %a : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+    ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+      // CHECK: linalg.yield %alice, %steve
+      linalg.yield %b, %c : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg2 = %c0 to %c4 step %c1 {
+    // CHECK-PASS-PRESERVE: %foo = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo
+    // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+    // CHECK-PASS-PRESERVE: memref.store %foo_1
+    %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+    memref.store %0, %arg1[%arg2] : memref<4xf32>
+  }
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/119996


More information about the Mlir-commits mailing list