[Mlir-commits] [mlir] [MLIR][LLVM] DI Expression Rewrite & Legalization (PR #77541)
Billy Zhu
llvmlistbot at llvm.org
Wed Jan 10 12:05:12 PST 2024
https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/77541
>From fc3c7a080f0a93a845252b9ca313a9747073398a Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jan 2024 14:14:51 -0800
Subject: [PATCH 1/5] implement rewriter and first pattern
---
.../Transforms/DIExpressionLegalization.h | 44 +++++++++++
.../LLVMIR/Transforms/DIExpressionRewriter.h | 66 +++++++++++++++++
.../Dialect/LLVMIR/Transforms/CMakeLists.txt | 2 +
.../Transforms/DIExpressionLegalization.cpp | 61 +++++++++++++++
.../Transforms/DIExpressionRewriter.cpp | 74 +++++++++++++++++++
.../LLVMIR/Transforms/LegalizeForExport.cpp | 2 +
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 +
.../LLVMIR/di-expression-legalization.mlir | 42 +++++++++++
8 files changed, 293 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
create mode 100644 mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
create mode 100644 mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
create mode 100644 mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
new file mode 100644
index 00000000000000..59ac4b9c933124
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
@@ -0,0 +1,44 @@
+//===- DIExpressionLegalization.h - DIExpression Legalization Patterns ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarations for known legalization patterns for DIExpressions that should
+// be performed before translation into llvm.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
+
+namespace mlir {
+namespace LLVM {
+
+//===----------------------------------------------------------------------===//
+// Rewrite Patterns
+//===----------------------------------------------------------------------===//
+
+/// Adjacent DW_OP_LLVM_fragment should be merged into one.
+///
+/// The merged fragment's offset is the sum of both offsets, and its size is
+/// taken from the first fragment (the one applied closer to the IR value).
+class MergeFragments : public DIExpressionRewriter::ExprRewritePattern {
+public:
+ /// Matches two consecutive DW_OP_LLVM_fragment operators at the front of
+ /// `operators`; returns `operators.begin()` when they are not present.
+ OpIterT match(OpIterRange operators) const override;
+ /// Replaces the two matched fragments with a single merged fragment.
+ SmallVector<OperatorT> replace(OpIterRange operators) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// Runner
+//===----------------------------------------------------------------------===//
+
+/// Registers all known legalization patterns declared in this header and
+/// applies them to the DIExpressionAttrs held by `op` and the operations
+/// nested within it, replacing each expression with its legalized form.
+void legalizeDIExpressionsRecursively(Operation *op);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
new file mode 100644
index 00000000000000..c1170dc49666ef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
@@ -0,0 +1,66 @@
+//===- DIExpressionRewriter.h - Rewriter for DIExpression operators -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A driver for running rewrite patterns on DIExpression operators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include <deque>
+
+namespace mlir {
+namespace LLVM {
+
+/// Rewriter for DIExpressionAttr.
+///
+/// Users of this rewriter register their own rewrite patterns. Each pattern
+/// matches on a contiguous range of LLVM DIExpressionElemAttrs, and can be
+/// used to rewrite it into a new range of DIExpressionElemAttrs of any length.
+///
+/// `simplify` scans the operators of an expression from left to right. At
+/// each position, patterns are tried in registration order; the first one
+/// whose `match` consumes at least one operator has that range replaced by
+/// its `replace` result, and scanning resumes at the same position so the
+/// replacement operators can be matched again. When no pattern matches, the
+/// current operator is kept as is and the scan moves past it.
+class DIExpressionRewriter {
+public:
+ /// A single DIExpression operator (opcode plus arguments).
+ using OperatorT = LLVM::DIExpressionElemAttr;
+
+ /// Base class for a rewrite over a prefix of the remaining operators.
+ class ExprRewritePattern {
+ public:
+ using OperatorT = DIExpressionRewriter::OperatorT;
+ using OpIterT = std::deque<OperatorT>::const_iterator;
+ using OpIterRange = llvm::iterator_range<OpIterT>;
+
+ virtual ~ExprRewritePattern() = default;
+ /// Check whether a particular prefix of operators matches this pattern.
+ /// The provided argument is guaranteed non-empty.
+ /// Return the iterator after the last matched element, or the begin
+ /// iterator of the provided range to signal that there is no match.
+ virtual OpIterT match(OpIterRange) const = 0;
+ /// Replace the operators with a new list of operators.
+ /// The provided argument is exactly the range from the begin iterator up
+ /// to the iterator previously returned by `match`. The returned operators
+ /// are offered to the patterns again.
+ virtual SmallVector<OperatorT> replace(OpIterRange) const = 0;
+ };
+
+ /// Register a rewrite pattern with the rewriter.
+ /// Rewrite patterns are attempted in the order of registration.
+ void addPattern(std::unique_ptr<ExprRewritePattern> pattern);
+
+ /// Simplify a DIExpression according to all the patterns registered and
+ /// return the resulting expression.
+ /// A non-negative `maxNumRewrites` will limit the number of rewrites this
+ /// rewriter applies; a negative value (the default) means no limit. Once the
+ /// limit is reached, the remaining operators are kept unchanged. Since
+ /// replacements are matched again, a pattern that can match its own output
+ /// may not terminate without such a limit.
+ LLVM::DIExpressionAttr simplify(LLVM::DIExpressionAttr expr,
+ int64_t maxNumRewrites = -1) const;
+
+private:
+ /// The registered patterns, in registration order.
+ SmallVector<std::unique_ptr<ExprRewritePattern>> patterns;
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index 47a2a251bf3e8b..c80494a440116b 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRLLVMIRTransforms
AddComdats.cpp
+ DIExpressionLegalization.cpp
+ DIExpressionRewriter.cpp
DIScopeForLLVMFuncOp.cpp
LegalizeForExport.cpp
OptimizeForNVVM.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
new file mode 100644
index 00000000000000..0dce5d02175746
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
@@ -0,0 +1,61 @@
+//===- DIExpressionLegalization.cpp - DIExpression Legalization Patterns --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
+
+#include "llvm/BinaryFormat/Dwarf.h"
+
+using namespace mlir;
+using namespace LLVM;
+
+//===----------------------------------------------------------------------===//
+// MergeFragments
+//===----------------------------------------------------------------------===//
+
+MergeFragments::OpIterT MergeFragments::match(OpIterRange operators) const {
+  // `replace` reads both the offset and the size of each fragment, so only
+  // treat an operator as mergeable when it is a DW_OP_LLVM_fragment carrying
+  // exactly those two arguments. Malformed fragments are left untouched
+  // instead of being indexed out of bounds.
+  auto isMergeableFragment = [](OperatorT op) {
+    return op.getOpcode() == llvm::dwarf::DW_OP_LLVM_fragment &&
+           op.getArguments().size() == 2;
+  };
+
+  // Match exactly two consecutive fragments at the front of the range;
+  // returning `begin` signals "no match" to the rewriter.
+  OpIterT it = operators.begin();
+  if (it == operators.end() || !isMergeableFragment(*it))
+    return operators.begin();
+
+  ++it;
+  if (it == operators.end() || !isMergeableFragment(*it))
+    return operators.begin();
+
+  return ++it;
+}
+
+SmallVector<MergeFragments::OperatorT>
+MergeFragments::replace(OpIterRange operators) const {
+  // `match` hands over exactly two consecutive fragment operators.
+  OpIterT cursor = operators.begin();
+  OperatorT earlier = *cursor;
+  ++cursor;
+  OperatorT later = *cursor;
+
+  // Offsets accumulate; the size comes from the earlier fragment, i.e. the
+  // one applied closer to the IR value.
+  uint64_t mergedOffset =
+      earlier.getArguments()[0] + later.getArguments()[0];
+  uint64_t mergedSize = earlier.getArguments()[1];
+
+  OperatorT merged =
+      OperatorT::get(earlier.getContext(), llvm::dwarf::DW_OP_LLVM_fragment,
+                     {mergedOffset, mergedSize});
+  SmallVector<OperatorT> replacement;
+  replacement.push_back(merged);
+  return replacement;
+}
+
+//===----------------------------------------------------------------------===//
+// Runner
+//===----------------------------------------------------------------------===//
+
+void mlir::LLVM::legalizeDIExpressionsRecursively(Operation *op) {
+  // Collect every known legalization pattern into a single rewriter.
+  LLVM::DIExpressionRewriter exprRewriter;
+  exprRewriter.addPattern(std::make_unique<MergeFragments>());
+
+  // Substitute each DIExpressionAttr held by `op` (and by the operations
+  // nested in it) with its rewritten form.
+  mlir::AttrTypeReplacer attrReplacer;
+  attrReplacer.addReplacement(
+      [&exprRewriter](LLVM::DIExpressionAttr expression) {
+        return exprRewriter.simplify(expression);
+      });
+  attrReplacer.recursivelyReplaceElementsIn(op);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
new file mode 100644
index 00000000000000..b17b684082b13a
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -0,0 +1,74 @@
+//===- DIExpressionRewriter.cpp - Rewriter for DIExpression operators -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace LLVM;
+
+#define DEBUG_TYPE "llvm-di-expression-simplifier"
+
+//===----------------------------------------------------------------------===//
+// DIExpressionRewriter
+//===----------------------------------------------------------------------===//
+
+void DIExpressionRewriter::addPattern(
+    std::unique_ptr<ExprRewritePattern> pattern) {
+  // The rewriter takes ownership; registration order is the match order.
+  patterns.push_back(std::move(pattern));
+}
+
+DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
+                                                int64_t maxNumRewrites) const {
+  ArrayRef<OperatorT> operators = expr.getOperations();
+
+  // `inputs` contains the unprocessed postfix of operators.
+  // `result` contains the already finalized prefix of operators.
+  // Invariant: concat(result, inputs) is equivalent to `operators` after some
+  // application of the rewrite patterns.
+  // Using a deque for inputs so that we have efficient front insertion and
+  // removal. Random access is not necessary for patterns.
+  std::deque<OperatorT> inputs(operators.begin(), operators.end());
+  SmallVector<OperatorT> result;
+
+  // A negative `maxNumRewrites` means "no limit".
+  int64_t numRewrites = 0;
+  while (!inputs.empty() &&
+         (maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {
+    bool foundMatch = false;
+    for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
+      ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
+      // Returning `begin` is how a pattern reports "no match".
+      if (matchEnd == inputs.begin())
+        continue;
+
+      foundMatch = true;
+      SmallVector<OperatorT> replacement =
+          pattern->replace(llvm::make_range(inputs.cbegin(), matchEnd));
+      // Put the replacement back at the front so it is offered to the
+      // patterns again (e.g. a chain of fragments merges pairwise).
+      inputs.erase(inputs.begin(), matchEnd);
+      inputs.insert(inputs.begin(), replacement.begin(), replacement.end());
+      ++numRewrites;
+      break;
+    }
+
+    if (!foundMatch) {
+      // If no match, pass along the current operator.
+      result.push_back(inputs.front());
+      inputs.pop_front();
+    }
+  }
+
+  // The loop only stops with operators left over when the rewrite limit cut
+  // it short. Reaching the limit exactly as the input runs out skips nothing,
+  // so it is not reported.
+  if (!inputs.empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
+               << maxNumRewrites << ")\n");
+    // Skip rewriting the rest.
+    result.append(inputs.begin(), inputs.end());
+  }
+
+  // Nothing was rewritten, so `result` equals the original operator list;
+  // return the existing (uniqued) attribute instead of re-uniquing it.
+  if (numRewrites == 0)
+    return expr;
+
+  return LLVM::DIExpressionAttr::get(expr.getContext(), result);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
index 61c1378d961210..1ac994fa5fb780 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -79,6 +80,7 @@ struct LegalizeForExportPass
: public LLVM::impl::LLVMLegalizeForExportBase<LegalizeForExportPass> {
void runOnOperation() override {
LLVM::ensureDistinctSuccessors(getOperation());
+ LLVM::legalizeDIExpressionsRecursively(getOperation());
}
};
} // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ce46a194ea7d9f..fbbfb5b83eb609 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
@@ -1568,6 +1569,7 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
return nullptr;
LLVM::ensureDistinctSuccessors(module);
+ LLVM::legalizeDIExpressionsRecursively(module);
ModuleTranslation translator(module, std::move(llvmModule));
llvm::IRBuilder<> llvmBuilder(llvmContext);
diff --git a/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir b/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir
new file mode 100644
index 00000000000000..60fbc8135be62d
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -llvm-legalize-for-export --split-input-file %s | FileCheck %s -check-prefix=CHECK-OPT
+// RUN: mlir-translate -mlir-to-llvmir --split-input-file %s | FileCheck %s -check-prefix=CHECK-TRANSLATE
+
+#di_file = #llvm.di_file<"foo.c" in "/mlir/">
+#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, sourceLanguage = DW_LANG_C, file = #di_file, producer = "MLIR", isOptimized = true, emissionKind = Full>
+#di_subprogram = #llvm.di_subprogram<compileUnit = #di_compile_unit, scope = #di_file, name = "simplify", file = #di_file, subprogramFlags = Definition>
+#i32_type = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "i32", sizeInBits = 32, encoding = DW_ATE_unsigned>
+#i8_type = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "i8", sizeInBits = 8, encoding = DW_ATE_unsigned>
+
+// struct0: {i8, i32}
+#struct0_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct0_first", baseType = #i8_type, sizeInBits = 8, alignInBits = 8>
+#struct0_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct0_second", baseType = #i32_type, sizeInBits = 32, alignInBits = 32, offsetInBits = 32>
+#struct0 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct0", sizeInBits = 64, alignInBits = 32, elements = #struct0_first, #struct0_second>
+
+// struct1: {i8, struct0}
+// The second member is 32-bit aligned, so it starts at bit 32.
+#struct1_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct1_first", baseType = #i8_type, sizeInBits = 8, alignInBits = 8>
+#struct1_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct1_second", baseType = #struct0, sizeInBits = 64, alignInBits = 32, offsetInBits = 32>
+#struct1 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct1", sizeInBits = 96, alignInBits = 32, elements = #struct1_first, #struct1_second>
+
+// struct2: {i32, struct1}
+// The second member follows the leading i32, so it starts at bit 32.
+#struct2_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct2_first", baseType = #i32_type, sizeInBits = 32, alignInBits = 32>
+#struct2_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct2_second", baseType = #struct1, sizeInBits = 96, alignInBits = 32, offsetInBits = 32>
+#struct2 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct2", sizeInBits = 128, alignInBits = 32, elements = #struct2_first, #struct2_second>
+
+#var0 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct0_var", file = #di_file, line = 10, alignInBits = 32, type = #struct0>
+#var1 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct1_var", file = #di_file, line = 10, alignInBits = 32, type = #struct1>
+#var2 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct2_var", file = #di_file, line = 10, alignInBits = 32, type = #struct2>
+
+#loc = loc("test.mlir":0:0)
+
+// Adjacent fragments are merged pairwise: offsets add up and the size of the
+// first (innermost) fragment is kept, so a chain of fragments collapses into a
+// single fragment describing the innermost member within the outer variable.
+llvm.func @merge_fragments(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+ // A lone fragment is left untouched.
+ // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]>
+ // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 32, 32))
+ llvm.intr.dbg.value #var0 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]> = %arg0 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
+ // struct1_second.struct0_second: 32 + 32 = bit 64, size 32.
+ // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(64, 32)]>
+ // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 64, 32))
+ llvm.intr.dbg.value #var1 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64)]> = %arg1 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
+ // struct2_second.struct1_second.struct0_second: 32 + 32 + 32 = bit 96, size 32.
+ // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(96, 32)]>
+ // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 96, 32))
+ llvm.intr.dbg.value #var2 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64), DW_OP_LLVM_fragment(32, 96)]> = %arg2 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
+ llvm.return
+}
>From c68be51f78735189c7c49115431668ce51703df2 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 10:04:36 -0800
Subject: [PATCH 2/5] Apply suggestions from code review
Co-authored-by: Tobias Gysi <tobias.gysi at nextsilicon.com>
---
.../mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h | 2 +-
mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
index c1170dc49666ef..48eafb319b4816 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
@@ -35,7 +35,7 @@ class DIExpressionRewriter {
using OpIterRange = llvm::iterator_range<OpIterT>;
virtual ~ExprRewritePattern() = default;
- /// Check whether a particular prefix of operators matches this pattern.
+ /// Checks whether a particular prefix of operators matches this pattern.
/// The provided argument is guaranteed non-empty.
/// Return the iterator after the last matched element.
virtual OpIterT match(OpIterRange) const = 0;
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
index 0dce5d02175746..7d3170bb968219 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
@@ -53,7 +53,7 @@ void mlir::LLVM::legalizeDIExpressionsRecursively(Operation *op) {
LLVM::DIExpressionRewriter rewriter;
rewriter.addPattern(std::make_unique<MergeFragments>());
- mlir::AttrTypeReplacer replacer;
+ AttrTypeReplacer replacer;
replacer.addReplacement([&rewriter](LLVM::DIExpressionAttr expr) {
return rewriter.simplify(expr);
});
>From edbe397ba46e6e10d57fd7fa9438092bbb5218a5 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 10:08:16 -0800
Subject: [PATCH 3/5] add example in comments
---
.../Dialect/LLVMIR/Transforms/DIExpressionLegalization.h | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
index 59ac4b9c933124..2faf19b788b3a2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
@@ -24,6 +24,13 @@ namespace LLVM {
//===----------------------------------------------------------------------===//
/// Adjacent DW_OP_LLVM_fragment should be merged into one.
+///
+/// E.g.
+/// #llvm.di_expression<[
+/// DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64)
+/// ]>
+/// =>
+/// #llvm.di_expression<[DW_OP_LLVM_fragment(64, 32)]>
class MergeFragments : public DIExpressionRewriter::ExprRewritePattern {
public:
OpIterT match(OpIterRange operators) const override;
>From 1c4d45b97206e5feb9da5004f5c32a812628d8fa Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 10:13:01 -0800
Subject: [PATCH 4/5] use optional
---
.../LLVMIR/Transforms/DIExpressionRewriter.h | 13 +++++++------
.../LLVMIR/Transforms/DIExpressionRewriter.cpp | 13 +++++++------
2 files changed, 14 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
index 48eafb319b4816..2d9841518a633a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
@@ -45,15 +45,16 @@ class DIExpressionRewriter {
virtual SmallVector<OperatorT> replace(OpIterRange) const = 0;
};
- /// Register a rewrite pattern with the simplifier.
- /// Rewriter patterns are attempted in the order of registration.
+ /// Register a rewrite pattern with the rewriter.
+ /// Rewrite patterns are attempted in the order of registration.
void addPattern(std::unique_ptr<ExprRewritePattern> pattern);
/// Simplify a DIExpression according to all the patterns registered.
- /// A non-negative `maxNumRewrites` will limit the number of rewrites this
- /// simplifier applies.
- LLVM::DIExpressionAttr simplify(LLVM::DIExpressionAttr expr,
- int64_t maxNumRewrites = -1) const;
+ /// An optional `maxNumRewrites` can be passed to limit the number of rewrites
+ /// that gets applied.
+ LLVM::DIExpressionAttr
+ simplify(LLVM::DIExpressionAttr expr,
+ std::optional<uint64_t> maxNumRewrites = {}) const;
private:
/// The registered patterns.
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
index b17b684082b13a..0b8cfa7c4ee0f4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -24,8 +24,9 @@ void DIExpressionRewriter::addPattern(
patterns.emplace_back(std::move(pattern));
}
-DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
- int64_t maxNumRewrites) const {
+DIExpressionAttr
+DIExpressionRewriter::simplify(DIExpressionAttr expr,
+ std::optional<uint64_t> maxNumRewrites) const {
ArrayRef<OperatorT> operators = expr.getOperations();
// `inputs` contains the unprocessed postfix of operators.
@@ -37,9 +38,9 @@ DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
std::deque<OperatorT> inputs(operators.begin(), operators.end());
SmallVector<OperatorT> result;
- int64_t numRewrites = 0;
- while (!inputs.empty() &&
- (maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {
+ uint64_t numRewrites = 0;
+ while (!inputs.empty() && (!maxNumRewrites.has_value() ||
+ numRewrites < maxNumRewrites.value())) {
bool foundMatch = false;
for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
@@ -62,7 +63,7 @@ DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
}
}
- if (maxNumRewrites >= 0 && numRewrites >= maxNumRewrites) {
+ if (maxNumRewrites.has_value() && numRewrites >= maxNumRewrites.value()) {
LLVM_DEBUG(llvm::dbgs()
<< "LLVMDIExpressionSimplifier exceeded max num rewrites ("
<< maxNumRewrites << ")\n");
>From 9ba7c49ec59e4d37f95323cd576aaf9c114068b5 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 12:04:56 -0800
Subject: [PATCH 5/5] style
---
mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
index 0b8cfa7c4ee0f4..6fdb2f8c196478 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -39,8 +39,8 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr,
SmallVector<OperatorT> result;
uint64_t numRewrites = 0;
- while (!inputs.empty() && (!maxNumRewrites.has_value() ||
- numRewrites < maxNumRewrites.value())) {
+ while (!inputs.empty() &&
+ (!maxNumRewrites || numRewrites < *maxNumRewrites)) {
bool foundMatch = false;
for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
@@ -63,7 +63,7 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr,
}
}
- if (maxNumRewrites.has_value() && numRewrites >= maxNumRewrites.value()) {
+ if (maxNumRewrites && numRewrites >= *maxNumRewrites) {
LLVM_DEBUG(llvm::dbgs()
<< "LLVMDIExpressionSimplifier exceeded max num rewrites ("
<< maxNumRewrites << ")\n");
More information about the Mlir-commits
mailing list