[Mlir-commits] [mlir] [MLIR][LLVM] DI Expression Rewrite & Legalization (PR #77541)

Billy Zhu llvmlistbot at llvm.org
Wed Jan 10 12:05:12 PST 2024


https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/77541

>From fc3c7a080f0a93a845252b9ca313a9747073398a Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jan 2024 14:14:51 -0800
Subject: [PATCH 1/5] implement rewriter and first pattern

---
 .../Transforms/DIExpressionLegalization.h     | 44 +++++++++++
 .../LLVMIR/Transforms/DIExpressionRewriter.h  | 66 +++++++++++++++++
 .../Dialect/LLVMIR/Transforms/CMakeLists.txt  |  2 +
 .../Transforms/DIExpressionLegalization.cpp   | 61 +++++++++++++++
 .../Transforms/DIExpressionRewriter.cpp       | 74 +++++++++++++++++++
 .../LLVMIR/Transforms/LegalizeForExport.cpp   |  2 +
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  2 +
 .../LLVMIR/di-expression-legalization.mlir    | 42 +++++++++++
 8 files changed, 293 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
 create mode 100644 mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
 create mode 100644 mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
 create mode 100644 mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
new file mode 100644
index 00000000000000..59ac4b9c933124
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
@@ -0,0 +1,44 @@
+//===- DIExpressionLegalization.h - DIExpression Legalization Patterns ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarations for known legalization patterns for DIExpressions that should
+// be performed before translation into LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
+
+namespace mlir {
+namespace LLVM {
+
+//===----------------------------------------------------------------------===//
+// Rewrite Patterns
+//===----------------------------------------------------------------------===//
+
+/// Adjacent DW_OP_LLVM_fragment should be merged into one.
+class MergeFragments : public DIExpressionRewriter::ExprRewritePattern {
+public:
+  OpIterT match(OpIterRange operators) const override;
+  SmallVector<OperatorT> replace(OpIterRange operators) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// Runner
+//===----------------------------------------------------------------------===//
+
+/// Register all known legalization patterns declared here and apply them to
+/// every DIExpressionAttr used by `op` or by any op nested within it.
+void legalizeDIExpressionsRecursively(Operation *op);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
new file mode 100644
index 00000000000000..c1170dc49666ef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
@@ -0,0 +1,66 @@
+//===- DIExpressionRewriter.h - Rewriter for DIExpression operators -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A driver for running rewrite patterns on DIExpression operators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include <deque>
+
+namespace mlir {
+namespace LLVM {
+
+/// Rewriter for DIExpressionAttr.
+///
+/// Users of this rewriter register their own rewrite patterns. Each pattern
+/// matches on a contiguous range of LLVM DIExpressionElemAttrs, and can be
+/// used to rewrite it into a new range of DIExpressionElemAttrs of any length.
+class DIExpressionRewriter {
+public:
+  using OperatorT = LLVM::DIExpressionElemAttr;
+
+  class ExprRewritePattern {
+  public:
+    using OperatorT = DIExpressionRewriter::OperatorT;
+    using OpIterT = std::deque<OperatorT>::const_iterator;
+    using OpIterRange = llvm::iterator_range<OpIterT>;
+
+    virtual ~ExprRewritePattern() = default;
+    /// Check whether a particular prefix of operators matches this pattern.
+    /// The provided argument is guaranteed non-empty.
+    /// Return the iterator after the last matched element.
+    virtual OpIterT match(OpIterRange) const = 0;
+    /// Replace the matched operators with a new list of operators.
+    /// The provided argument is exactly the prefix accepted by `match`: it
+    /// starts at the first operator and ends at the iterator `match` returned.
+    virtual SmallVector<OperatorT> replace(OpIterRange) const = 0;
+  };
+
+  /// Register a rewrite pattern with the simplifier.
+  /// Rewriter patterns are attempted in the order of registration.
+  void addPattern(std::unique_ptr<ExprRewritePattern> pattern);
+
+  /// Simplify a DIExpression according to all the patterns registered.
+  /// A non-negative `maxNumRewrites` will limit the number of rewrites this
+  /// simplifier applies.
+  LLVM::DIExpressionAttr simplify(LLVM::DIExpressionAttr expr,
+                                  int64_t maxNumRewrites = -1) const;
+
+private:
+  /// The registered patterns.
+  SmallVector<std::unique_ptr<ExprRewritePattern>> patterns;
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index 47a2a251bf3e8b..c80494a440116b 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRLLVMIRTransforms
   AddComdats.cpp
+  DIExpressionLegalization.cpp
+  DIExpressionRewriter.cpp
   DIScopeForLLVMFuncOp.cpp
   LegalizeForExport.cpp
   OptimizeForNVVM.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
new file mode 100644
index 00000000000000..0dce5d02175746
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
@@ -0,0 +1,61 @@
+//===- DIExpressionLegalization.cpp - DIExpression Legalization Patterns --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
+
+#include "llvm/BinaryFormat/Dwarf.h"
+
+using namespace mlir;
+using namespace LLVM;
+
+//===----------------------------------------------------------------------===//
+// MergeFragments
+//===----------------------------------------------------------------------===//
+
+MergeFragments::OpIterT MergeFragments::match(OpIterRange operators) const {
+  OpIterT it = operators.begin();
+  if (it == operators.end() ||
+      it->getOpcode() != llvm::dwarf::DW_OP_LLVM_fragment)
+    return operators.begin();
+
+  ++it;
+  if (it == operators.end() ||
+      it->getOpcode() != llvm::dwarf::DW_OP_LLVM_fragment)
+    return operators.begin();
+
+  return ++it;
+}
+
+SmallVector<MergeFragments::OperatorT>
+MergeFragments::replace(OpIterRange operators) const {
+  OpIterT it = operators.begin();
+  OperatorT first = *(it++);
+  OperatorT second = *it;
+  // Add offsets & select the size of the earlier operator (the one closer to
+  // the IR value).
+  uint64_t offset = first.getArguments()[0] + second.getArguments()[0];
+  uint64_t size = first.getArguments()[1];
+  OperatorT newOp = OperatorT::get(
+      first.getContext(), llvm::dwarf::DW_OP_LLVM_fragment, {offset, size});
+  return SmallVector<OperatorT>{newOp};
+}
+
+//===----------------------------------------------------------------------===//
+// Runner
+//===----------------------------------------------------------------------===//
+
+void mlir::LLVM::legalizeDIExpressionsRecursively(Operation *op) {
+  LLVM::DIExpressionRewriter rewriter;
+  rewriter.addPattern(std::make_unique<MergeFragments>());
+
+  mlir::AttrTypeReplacer replacer;
+  replacer.addReplacement([&rewriter](LLVM::DIExpressionAttr expr) {
+    return rewriter.simplify(expr);
+  });
+  replacer.recursivelyReplaceElementsIn(op);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
new file mode 100644
index 00000000000000..b17b684082b13a
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -0,0 +1,74 @@
+//===- DIExpressionRewriter.cpp - Rewriter for DIExpression operators -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace LLVM;
+
+#define DEBUG_TYPE "llvm-di-expression-simplifier"
+
+//===----------------------------------------------------------------------===//
+// DIExpressionRewriter
+//===----------------------------------------------------------------------===//
+
+void DIExpressionRewriter::addPattern(
+    std::unique_ptr<ExprRewritePattern> pattern) {
+  patterns.emplace_back(std::move(pattern));
+}
+
+DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
+                                                int64_t maxNumRewrites) const {
+  ArrayRef<OperatorT> operators = expr.getOperations();
+
+  // `inputs` contains the unprocessed postfix of operators.
+  // `result` contains the already finalized prefix of operators.
+  // Invariant: concat(result, inputs) is equivalent to `operators` after some
+  // application of the rewrite patterns.
+  // Using a deque for inputs so that we have efficient front insertion and
+  // removal. Random access is not necessary for patterns.
+  std::deque<OperatorT> inputs(operators.begin(), operators.end());
+  SmallVector<OperatorT> result;
+
+  int64_t numRewrites = 0;
+  while (!inputs.empty() &&
+         (maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {
+    bool foundMatch = false;
+    for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
+      ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
+      if (matchEnd == inputs.begin())
+        continue;
+
+      foundMatch = true;
+      SmallVector<OperatorT> replacement =
+          pattern->replace(llvm::make_range(inputs.cbegin(), matchEnd));
+      inputs.erase(inputs.begin(), matchEnd);
+      inputs.insert(inputs.begin(), replacement.begin(), replacement.end());
+      ++numRewrites;
+      break;
+    }
+
+    if (!foundMatch) {
+      // If no match, pass along the current operator.
+      result.push_back(inputs.front());
+      inputs.pop_front();
+    }
+  }
+
+  if (maxNumRewrites >= 0 && numRewrites >= maxNumRewrites) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
+               << maxNumRewrites << ")\n");
+    // Skip rewriting the rest.
+    result.append(inputs.begin(), inputs.end());
+  }
+
+  return LLVM::DIExpressionAttr::get(expr.getContext(), result);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
index 61c1378d961210..1ac994fa5fb780 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -79,6 +80,7 @@ struct LegalizeForExportPass
     : public LLVM::impl::LLVMLegalizeForExportBase<LegalizeForExportPass> {
   void runOnOperation() override {
     LLVM::ensureDistinctSuccessors(getOperation());
+    LLVM::legalizeDIExpressionsRecursively(getOperation());
   }
 };
 } // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ce46a194ea7d9f..fbbfb5b83eb609 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
@@ -1568,6 +1569,7 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
     return nullptr;
 
   LLVM::ensureDistinctSuccessors(module);
+  LLVM::legalizeDIExpressionsRecursively(module);
 
   ModuleTranslation translator(module, std::move(llvmModule));
   llvm::IRBuilder<> llvmBuilder(llvmContext);
diff --git a/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir b/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir
new file mode 100644
index 00000000000000..60fbc8135be62d
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -llvm-legalize-for-export --split-input-file  %s | FileCheck %s -check-prefix=CHECK-OPT
+// RUN: mlir-translate -mlir-to-llvmir --split-input-file %s | FileCheck %s -check-prefix=CHECK-TRANSLATE
+
+#di_file = #llvm.di_file<"foo.c" in "/mlir/">
+#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, sourceLanguage = DW_LANG_C, file = #di_file, producer = "MLIR", isOptimized = true, emissionKind = Full>
+#di_subprogram = #llvm.di_subprogram<compileUnit = #di_compile_unit, scope = #di_file, name = "simplify", file = #di_file, subprogramFlags = Definition>
+#i32_type = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "i32", sizeInBits = 32, encoding = DW_ATE_unsigned>
+#i8_type = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "i8", sizeInBits = 8, encoding = DW_ATE_unsigned>
+
+// struct0: {i8, i32}
+#struct0_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct0_first", baseType = #i8_type, sizeInBits = 8, alignInBits = 8>
+#struct0_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct0_second", baseType = #i32_type, sizeInBits = 32, alignInBits = 32, offsetInBits = 32>
+#struct0 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct0", sizeInBits = 64, alignInBits = 32, elements = #struct0_first, #struct0_second>
+
+// struct1: {i8, struct0}
+#struct1_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct1_first", baseType = #i8_type, sizeInBits = 8, alignInBits = 8>
+#struct1_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct1_second", baseType = #struct0, sizeInBits = 64, alignInBits = 32>
+#struct1 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct1", sizeInBits = 96, alignInBits = 32, elements = #struct1_first, #struct1_second>
+
+// struct2: {i32, struct1}
+#struct2_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct2_first", baseType = #i32_type, sizeInBits = 32, alignInBits = 32>
+#struct2_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct2_second", baseType = #struct1, sizeInBits = 96, alignInBits = 32>
+#struct2 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct2", sizeInBits = 128, alignInBits = 32, elements = #struct2_first, #struct2_second>
+
+#var0 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct0_var", file = #di_file, line = 10, alignInBits = 32, type = #struct0>
+#var1 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct1_var", file = #di_file, line = 10, alignInBits = 32, type = #struct1>
+#var2 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct2_var", file = #di_file, line = 10, alignInBits = 32, type = #struct2>
+
+#loc = loc("test.mlir":0:0)
+
+llvm.func @merge_fragments(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+  // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]>
+  // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 32, 32))
+  llvm.intr.dbg.value #var0 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]> = %arg0 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
+  // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(64, 32)]>
+  // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 64, 32))
+  llvm.intr.dbg.value #var1 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64)]> = %arg1 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
+  // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(96, 32)]>
+  // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 96, 32))
+  llvm.intr.dbg.value #var2 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64), DW_OP_LLVM_fragment(32, 96)]> = %arg2 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
+  llvm.return
+}

>From c68be51f78735189c7c49115431668ce51703df2 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 10:04:36 -0800
Subject: [PATCH 2/5] Apply suggestions from code review

Co-authored-by: Tobias Gysi <tobias.gysi at nextsilicon.com>
---
 .../mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h       | 2 +-
 mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
index c1170dc49666ef..48eafb319b4816 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
@@ -35,7 +35,7 @@ class DIExpressionRewriter {
     using OpIterRange = llvm::iterator_range<OpIterT>;
 
     virtual ~ExprRewritePattern() = default;
-    /// Check whether a particular prefix of operators matches this pattern.
+    /// Checks whether a particular prefix of operators matches this pattern.
     /// The provided argument is guaranteed non-empty.
     /// Return the iterator after the last matched element.
     virtual OpIterT match(OpIterRange) const = 0;
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
index 0dce5d02175746..7d3170bb968219 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
@@ -53,7 +53,7 @@ void mlir::LLVM::legalizeDIExpressionsRecursively(Operation *op) {
   LLVM::DIExpressionRewriter rewriter;
   rewriter.addPattern(std::make_unique<MergeFragments>());
 
-  mlir::AttrTypeReplacer replacer;
+  AttrTypeReplacer replacer;
   replacer.addReplacement([&rewriter](LLVM::DIExpressionAttr expr) {
     return rewriter.simplify(expr);
   });

>From edbe397ba46e6e10d57fd7fa9438092bbb5218a5 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 10:08:16 -0800
Subject: [PATCH 3/5] add example in comments

---
 .../Dialect/LLVMIR/Transforms/DIExpressionLegalization.h   | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
index 59ac4b9c933124..2faf19b788b3a2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
@@ -24,6 +24,13 @@ namespace LLVM {
 //===----------------------------------------------------------------------===//
 
 /// Adjacent DW_OP_LLVM_fragment should be merged into one.
+///
+/// E.g.
+///   #llvm.di_expression<[
+///     DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64)
+///   ]>
+/// =>
+///   #llvm.di_expression<[DW_OP_LLVM_fragment(64, 32)]>
 class MergeFragments : public DIExpressionRewriter::ExprRewritePattern {
 public:
   OpIterT match(OpIterRange operators) const override;

>From 1c4d45b97206e5feb9da5004f5c32a812628d8fa Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 10:13:01 -0800
Subject: [PATCH 4/5] use optional

---
 .../LLVMIR/Transforms/DIExpressionRewriter.h        | 13 +++++++------
 .../LLVMIR/Transforms/DIExpressionRewriter.cpp      | 13 +++++++------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
index 48eafb319b4816..2d9841518a633a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
@@ -45,15 +45,16 @@ class DIExpressionRewriter {
     virtual SmallVector<OperatorT> replace(OpIterRange) const = 0;
   };
 
-  /// Register a rewrite pattern with the simplifier.
-  /// Rewriter patterns are attempted in the order of registration.
+  /// Register a rewrite pattern with the rewriter.
+  /// Rewrite patterns are attempted in the order of registration.
   void addPattern(std::unique_ptr<ExprRewritePattern> pattern);
 
   /// Simplify a DIExpression according to all the patterns registered.
-  /// A non-negative `maxNumRewrites` will limit the number of rewrites this
-  /// simplifier applies.
-  LLVM::DIExpressionAttr simplify(LLVM::DIExpressionAttr expr,
-                                  int64_t maxNumRewrites = -1) const;
+  /// An optional `maxNumRewrites` can be passed to limit the number of rewrites
+  /// that gets applied.
+  LLVM::DIExpressionAttr
+  simplify(LLVM::DIExpressionAttr expr,
+           std::optional<uint64_t> maxNumRewrites = {}) const;
 
 private:
   /// The registered patterns.
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
index b17b684082b13a..0b8cfa7c4ee0f4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -24,8 +24,9 @@ void DIExpressionRewriter::addPattern(
   patterns.emplace_back(std::move(pattern));
 }
 
-DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
-                                                int64_t maxNumRewrites) const {
+DIExpressionAttr
+DIExpressionRewriter::simplify(DIExpressionAttr expr,
+                               std::optional<uint64_t> maxNumRewrites) const {
   ArrayRef<OperatorT> operators = expr.getOperations();
 
   // `inputs` contains the unprocessed postfix of operators.
@@ -37,9 +38,9 @@ DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
   std::deque<OperatorT> inputs(operators.begin(), operators.end());
   SmallVector<OperatorT> result;
 
-  int64_t numRewrites = 0;
-  while (!inputs.empty() &&
-         (maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {
+  uint64_t numRewrites = 0;
+  while (!inputs.empty() && (!maxNumRewrites.has_value() ||
+                             numRewrites < maxNumRewrites.value())) {
     bool foundMatch = false;
     for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
       ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
@@ -62,7 +63,7 @@ DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
     }
   }
 
-  if (maxNumRewrites >= 0 && numRewrites >= maxNumRewrites) {
+  if (maxNumRewrites.has_value() && numRewrites >= maxNumRewrites.value()) {
     LLVM_DEBUG(llvm::dbgs()
                << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
                << maxNumRewrites << ")\n");

>From 9ba7c49ec59e4d37f95323cd576aaf9c114068b5 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 10 Jan 2024 12:04:56 -0800
Subject: [PATCH 5/5] style

---
 mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
index 0b8cfa7c4ee0f4..6fdb2f8c196478 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -39,8 +39,8 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr,
   SmallVector<OperatorT> result;
 
   uint64_t numRewrites = 0;
-  while (!inputs.empty() && (!maxNumRewrites.has_value() ||
-                             numRewrites < maxNumRewrites.value())) {
+  while (!inputs.empty() &&
+         (!maxNumRewrites || numRewrites < *maxNumRewrites)) {
     bool foundMatch = false;
     for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
       ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
@@ -63,7 +63,7 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr,
     }
   }
 
-  if (maxNumRewrites.has_value() && numRewrites >= maxNumRewrites.value()) {
+  if (maxNumRewrites && numRewrites >= *maxNumRewrites) {
     LLVM_DEBUG(llvm::dbgs()
                << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
                << maxNumRewrites << ")\n");



More information about the Mlir-commits mailing list