[Mlir-commits] [mlir] f539e00 - [mlir] add option to print SSA IDs using `NameLoc`s as prefixes (#119996)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 17 07:59:16 PST 2024
Author: Maksim Levental
Date: 2024-12-17T07:59:11-08:00
New Revision: f539e00c702b4e5732d76e093c2d909fd8702683
URL: https://github.com/llvm/llvm-project/commit/f539e00c702b4e5732d76e093c2d909fd8702683
DIFF: https://github.com/llvm/llvm-project/commit/f539e00c702b4e5732d76e093c2d909fd8702683.diff
LOG: [mlir] add option to print SSA IDs using `NameLoc`s as prefixes (#119996)
This PR adds an `AsmPrinter` option `-mlir-use-nameloc-as-prefix` which
uses trailing `NameLoc`s, if the source IR provides them, as prefixes
when printing SSA IDs.
Added:
mlir/test/IR/print-use-nameloc-as-prefix.mlir
Modified:
mlir/docs/PassManagement.md
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/AsmPrinter.cpp
Removed:
################################################################################
diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 7b19a7bf6bf471..9fb0aaab064619 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -1398,6 +1398,27 @@ $ tree /tmp/pipeline_output
│ │ ├── 1_1_pass4.mlir
```
+* `mlir-use-nameloc-as-prefix`
+ * If your source IR has named locations (`loc("named_location")`), then passing this flag will use those
+ names (`named_location`) to prefix the corresponding SSA identifiers:
+
+ ```mlir
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ %2 = memref.load %0[] : memref<i32> loc("bob")
+ %3 = memref.load %0[] : memref<i32> loc("bob")
+ ```
+
+ will print
+
+ ```mlir
+ %alice = memref.load %0[] : memref<i32>
+ %bob = memref.load %0[] : memref<i32>
+ %bob_0 = memref.load %0[] : memref<i32>
+ ```
+
+ These names are also preserved through passes, provided the newly created operations are given the original named location.
+
+
## Crash and Failure Reproduction
The [pass manager](#pass-manager) in MLIR contains a builtin mechanism to
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..9f2de582b03e56 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
/// Return if printer should use unique SSA IDs.
bool shouldPrintUniqueSSAIDs() const;
+ /// Return if the printer should use NameLocs as prefixes when printing SSA
+ /// IDs
+ bool shouldUseNameLocAsPrefix() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
/// Print unique SSA IDs for values, block arguments and naming conflicts
bool printUniqueSSAIDsFlag : 1;
+
+ /// Print SSA IDs using NameLocs as prefixes
+ bool useNameLocAsPrefix : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..99b7abe7db1f94 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
"mlir-print-unique-ssa-ids", llvm::cl::init(false),
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
"and naming conflicts across all regions")};
+
+ llvm::cl::opt<bool> useNameLocAsPrefix{
+ "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
+ llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
};
} // namespace
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
- printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+ printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+ useNameLocAsPrefix(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
skipRegionsFlag = clOptions->skipRegionsOpt;
printValueUsersFlag = clOptions->printValueUsers;
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+ useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
}
+/// Return if the printer should use NameLocs as prefixes when printing SSA IDs.
+bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
+ return useNameLocAsPrefix;
+}
+
//===----------------------------------------------------------------------===//
// NewLineCounter
//===----------------------------------------------------------------------===//
@@ -1506,11 +1518,22 @@ void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
}
}
+namespace {
+/// Try to get value name from value's location, fallback to `name`.
+StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
+ if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>())
+ return maybeNameLoc.getName();
+ return name;
+}
+} // namespace
+
void SSANameState::numberValuesInRegion(Region ®ion) {
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion &&
"arg not defined in current region");
+ if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
+ name = maybeGetValueNameFromLoc(arg, name);
setValueName(arg, name);
};
@@ -1553,7 +1576,10 @@ void SSANameState::numberValuesInBlock(Block &block) {
specialNameBuffer.resize(strlen("arg"));
specialName << nextArgumentID++;
}
- setValueName(arg, specialName.str());
+ StringRef specialNameStr = specialName.str();
+ if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
+ specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr);
+ setValueName(arg, specialNameStr);
}
// Number the operations in this block.
@@ -1567,6 +1593,8 @@ void SSANameState::numberValuesInOp(Operation &op) {
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+ if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
+ name = maybeGetValueNameFromLoc(result, name);
setValueName(result, name);
// Record the result number for groups not anchored at 0.
@@ -1607,6 +1635,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
Value resultBegin = op.getResult(0);
+ if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) {
+ if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) {
+ setValueName(resultBegin, nameLoc.getName());
+ }
+ }
+
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
++nextValueID;
diff --git a/mlir/test/IR/print-use-nameloc-as-prefix.mlir b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
new file mode 100644
index 00000000000000..ddee8aed5586cf
--- /dev/null
+++ b/mlir/test/IR/print-use-nameloc-as-prefix.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
+
+// CHECK-LABEL: test_basic
+func.func @test_basic() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice = memref.load
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_repeat_namelocs
+func.func @test_repeat_namelocs() {
+ %0 = memref.alloc() : memref<i32>
+ // CHECK: %alice = memref.load
+ %1 = memref.load %0[] : memref<i32> loc("alice")
+ // CHECK: %alice_0 = memref.load
+ %2 = memref.load %0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_bb_args
+func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
+ // CHECK: %alice = memref.load %foo
+ %1 = memref.load %arg0[] : memref<i32> loc("alice")
+ return
+}
+
+// -----
+
+func.func private @make_two_results() -> (index, index)
+
+// CHECK-LABEL: test_multiple_results
+func.func @test_multiple_results(%cond: i1) {
+ // CHECK: %foo:2 = call @make_two_results
+ %0:2 = call @make_two_results() : () -> (index, index) loc("foo")
+ // CHECK: %bar:2 = call @make_two_results
+ %1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
+
+ // CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
+ %5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
+ %6 = arith.cmpi slt, %arg1, %arg2 : index
+ scf.condition(%6) %arg1, %arg2 : index, index
+ } do {
+ // CHECK: ^bb0(%alice: index, %bob: index)
+ ^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
+ %c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
+ // CHECK: scf.yield %harriet#1, %harriet#1
+ scf.yield %c2, %c2 : index, index
+ } loc("kevin")
+ return
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+#trait = {
+ iterator_types = ["parallel"],
+ indexing_maps = [#map, #map, #map]
+}
+
+// CHECK-LABEL: test_op_asm_interface
+func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %c0 = arith.constant
+ %0 = arith.constant 0 : index
+ // CHECK: %foo = arith.constant
+ %1 = arith.constant 1 : index loc("foo")
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ linalg.yield %a, %a : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
+ ^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
+ // CHECK: linalg.yield %alice, %steve
+ linalg.yield %b, %c : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_pass
+func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ scf.for %arg2 = %c0 to %c4 step %c1 {
+ // CHECK-PASS-PRESERVE: %foo = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo
+ // CHECK-PASS-PRESERVE: %foo_1 = memref.load
+ // CHECK-PASS-PRESERVE: memref.store %foo_1
+ %0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
+ memref.store %0, %arg1[%arg2] : memref<4xf32>
+ }
+ return
+}
More information about the Mlir-commits
mailing list