[Mlir-commits] [flang] [mlir] [mlir] add option to print SSA IDs using `NameLoc`s as prefixes (PR #119996)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Dec 15 15:13:43 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
This PR adds an `AsmPrinter` option, `-mlir-use-nameloc-as-prefix`, that uses trailing `NameLoc`s (when the source IR provides them) as prefixes when printing SSA IDs. **Note**: this PR only changes `AsmPrinter`.
For example:
```mlir
%1 = memref.load %0[] : memref<i32> loc("alice")
```
prints
```mlir
%alice = memref.load %0[] : memref<i32>
```
Currently, operations with single or multiple results, as well as block arguments, are handled:
```mlir
%1:2 = call @make_two_results() : () -> (index, index) loc("bar")
%2:2 = scf.while (%arg1 = %1#0, %arg2 = %1#0) : (index, index) -> (index, index) {
%3 = arith.cmpi slt, %arg1, %arg2 : index
scf.condition(%3) %arg1, %arg2 : index, index
} do {
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
%c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
scf.yield %c2, %c2 : index, index
} loc("kevin")
```
becomes
```mlir
%bar:2 = call @make_two_results() : () -> (index, index)
%kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0) : (index, index) -> (index, index) {
%0 = arith.cmpi slt, %arg1, %arg2 : index
scf.condition(%0) %arg1, %arg2 : index, index
} do {
^bb0(%alice: index, %bob: index):
%harriet:2 = func.call @make_two_results() : () -> (index, index)
scf.yield %harriet#1, %harriet#1 : index, index
}
```
The changes here are also compatible with `OpAsmOpInterface`. Note, however, that if an op implements `getAsmBlockArgumentNames` but not `getAsmResultNames` (like [`linalg.generic`](https://github.com/makslevental/llvm-project/blob/288f05f63e5f3246657aca9561d75b2aa02cb6f5/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td#L56)), then a `loc("...")` attached to the op's results will not be processed, because [`setResultNameFn`](https://github.com/llvm/llvm-project/blob/d61562c967b3a6f0c34e00c90155d16580288d8a/mlir/lib/IR/AsmPrinter.cpp#L1601) is never called. `NameLoc`s on that op's block arguments are still used.
### Testing
Besides the added lit test, I "stress tested" this by turning it on by default and checking whether anything broke ([this commit](https://github.com/llvm/llvm-project/pull/119996/commits/bbca902568bfb047ef67a96dad631314669fd6f9)). I then fixed my bugs and stress tested again (this [commit](https://github.com/llvm/llvm-project/pull/119996/commits/ca084f95b2c1a424392049bc4a8555f4f23d83fe) and this [passing test](https://buildkite.com/llvm-project/github-pull-requests/builds/129062#0193c8fd-9284-4c5f-b4c0-ec337d75f7be); the Windows test failure there was a flake).
---
Full diff: https://github.com/llvm/llvm-project/pull/119996.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/OperationSupport.h (+7)
- (modified) mlir/lib/IR/AsmPrinter.cpp (+57-6)
- (added) mlir/test/IR/print-use-nameloc-as-prefix.mlir (+105)
``````````diff
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
/// Return if printer should use unique SSA IDs.
bool shouldPrintUniqueSSAIDs() const;
+ /// Return if the printer should use NameLocs as prefixes when printing SSA
+ /// IDs
+ bool shouldUseNameLocAsPrefix() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
/// Print unique SSA IDs for values, block arguments and naming conflicts
bool printUniqueSSAIDsFlag : 1;
+
+ /// Print SSA IDs using NameLocs as prefixes
+ bool useNameLocAsPrefix : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..1a81414b358d06 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
"mlir-print-unique-ssa-ids", llvm::cl::init(false),
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
"and naming conflicts across all regions")};
+
+ llvm::cl::opt<bool> useNameLocAsPrefix{
+ "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+ llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
};
} // namespace
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
- printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+ printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+ useNameLocAsPrefix(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
skipRegionsFlag = clOptions->skipRegionsOpt;
printValueUsersFlag = clOptions->printValueUsers;
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+ useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
}
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+ return useNameLocAsPrefix;
+}
+
//===----------------------------------------------------------------------===//
// NewLineCounter
//===----------------------------------------------------------------------===//
@@ -1511,13 +1523,30 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion &&
"arg not defined in current region");
- setValueName(arg, name);
+ if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setValueName(arg, nameLoc.getName());
+ } else {
+ setValueName(arg, name);
+ }
};
+ bool alreadySetNames = false;
if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
- if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+ if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) {
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ alreadySetNames = true;
+ }
+ }
+ }
+
+ if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames) {
+ for (BlockArgument arg : region.getArguments()) {
+ if (isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setBlockArgNameFn(arg, nameLoc.getName());
+ }
}
}
@@ -1553,7 +1582,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
specialNameBuffer.resize(strlen("arg"));
specialName << nextArgumentID++;
}
- setValueName(arg, specialName.str());
+ if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
+ auto nameLoc = cast<NameLoc>(arg.getLoc());
+ setValueName(arg, nameLoc.getName());
+ } else {
+ setValueName(arg, specialName.str());
+ }
}
// Number the operations in this block.
@@ -1567,7 +1601,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
- setValueName(result, name);
+ if (printerFlags.shouldUseNameLocAsPrefix() &&
+ isa<NameLoc>(result.getLoc())) {
+ auto nameLoc = cast<NameLoc>(result.getLoc());
+ setValueName(result, nameLoc.getName());
+ } else {
+ setValueName(result, name);
+ }
// Record the result number for groups not anchored at 0.
if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1589,14 +1629,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
blockNames[block] = {-1, name};
};
+ bool alreadySetNames = false;
if (!printerFlags.shouldPrintGenericOpForm()) {
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
+ alreadySetNames = true;
}
}
unsigned numResults = op.getNumResults();
+ if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames &&
+ numResults > 0) {
+ Value resultBegin = op.getResult(0);
+ if (isa<NameLoc>(resultBegin.getLoc())) {
+ auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
+ setResultNameFn(resultBegin, nameLoc.getName());
+ }
+ }
+
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
if (printerFlags.shouldPrintValueUsers()) {
diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..fb555d9708ee86
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ // CHECK: %alice_0
+ %2 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
+ // CHECK: %alice = memref.load %foo
+ %1 = memref.load %arg0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) {
+ // CHECK: %foo:2
+ %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+ // CHECK: %bar:2
+ %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+ // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+ %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+ %6 = arith.cmpi slt, %arg1, %arg2 : index
+ scf.condition(%6) %arg1, %arg2 : index, index
+ } do {
+ // CHECK: ^bb0(%alice: index, %bob: index)
+ ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+ %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+ // CHECK: scf.yield %harriet#1, %harriet#1
+ scf.yield %c2, %c2 : index, index
+ } loc("kevin")
+ return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+ iterator_types = ["parallel"],
+ indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %c0
+ %0 = arith.constant 0 : index
+ // CHECK: %foo
+ %1 = arith.constant 1 : index loc("foo")
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ linalg.yield %a, %a : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+ ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+ // CHECK: linalg.yield %alice, %steve
+ linalg.yield %b, %c : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ scf.for %arg2 = %c0 to %c4 step %c1 {
+ // CHECK-PASS-PRESERVE: %foo = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo
+ // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo_1
+ %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+ memref.store %0, %arg1[%arg2] : memref<4xf32>
+ }
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/119996
More information about the Mlir-commits
mailing list