[Mlir-commits] [mlir] [mlir][LLVMIR] Add folder pass for `llvm.inttoptr` and `llvm.ptrtoint` (PR #143066)

Diego Caballero llvmlistbot at llvm.org
Thu Jun 5 22:02:32 PDT 2025


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/143066

This PR is a follow-up to https://github.com/llvm/llvm-project/pull/141891. It introduces a pass that folds `inttoptr(ptrtoint(x)) -> x` and `ptrtoint(inttoptr(x)) -> x`. The pass takes a list of per-address-space pointer bitwidths and applies a fold only when it is known to be safe, i.e., when no bits of `x` can be lost along the conversion chain.

>From 23924b917579882d56bd99ca429a8b1670e3e432 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Fri, 6 Jun 2025 04:52:17 +0000
Subject: [PATCH] [mlir][LLVMIR] Add folder pass for `llvm.inttoptr` and
 `llvm.ptrtoint`

This PR is a follow-up to https://github.com/llvm/llvm-project/pull/141891.
It introduces a pass that can fold `inttoptr(ptrtoint(x)) -> x` and
`ptrtoint(inttoptr(x)) -> x`. The pass takes in a list of address space
bitwidths and makes sure that the folding is applied only when it's safe.
---
 .../Transforms/IntToPtrPtrToIntFolding.h      |  43 ++++++
 .../mlir/Dialect/LLVMIR/Transforms/Passes.td  |  19 +++
 .../Dialect/LLVMIR/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/IntToPtrPtrToIntFolding.cpp    | 133 ++++++++++++++++++
 .../LLVMIR/inttoptr-ptrtoint-folding.mlir     | 100 +++++++++++++
 5 files changed, 296 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
 create mode 100644 mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
 create mode 100644 mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
new file mode 100644
index 0000000000000..62a9c84241a39
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
@@ -0,0 +1,43 @@
+//===- IntToPtrPtrToIntFolding.h - IntToPtr/PtrToInt folding ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a pass that folds inttoptr/ptrtoint operation sequences.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace LLVM {
+
+#define GEN_PASS_DECL_FOLDINTTOPTRPTRTOINTPASS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+/// Populate `patterns` with patterns that fold inttoptr/ptrtoint op sequences
+/// such as:
+///
+///   * `inttoptr(ptrtoint(x))` -> `x`
+///   * `ptrtoint(inttoptr(x))` -> `x`
+///
+/// A sequence is folded only when the final result type is identical to the
+/// type of `x` and the intermediate type is at least as wide as the type of
+/// `x`, so that no bits of `x` are lost along the way.
+///
+/// `addressSpaceBWs` contains the pointer bitwidth for each address space,
+/// indexed by address space number (entry `i` describes address space `i`).
+/// If there is no entry for an address space, or its entry is 0, the pointer
+/// bitwidth of that address space is treated as unknown and no folding
+/// involving that address space is performed.
+///
+/// TODO: Support DLTI.
+void populateIntToPtrPtrToIntFoldingPatterns(
+    RewritePatternSet &patterns, ArrayRef<unsigned> addressSpaceBWs);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index 961909d5c8d27..be45213e6b95e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -73,4 +73,23 @@ def DIScopeForLLVMFuncOpPass : Pass<"ensure-debug-info-scope-on-llvm-func", "::m
   ];
 }
 
+def FoldIntToPtrPtrToIntPass : Pass<"fold-llvm-inttoptr-ptrtoint",
+                                    "::mlir::LLVM::LLVMFuncOp"> {
+  let summary = "Fold inttoptr/ptrtoint operation sequences";
+  let description = [{
+    This pass folds sequences of inttoptr and ptrtoint operations that cancel
+    each other out. Specifically:
+      * inttoptr(ptrtoint(x)) -> x
+      * ptrtoint(inttoptr(x)) -> x
+
+    A sequence is only folded when the final result type matches the type of
+    `x` and the intermediate type is at least as wide as the type of `x`, so
+    that no bits of `x` are lost along the way.
+
+    The pass takes a list of pointer bitwidths, one per address space, to make
+    sure folding is safe: the i-th entry is the pointer bitwidth of address
+    space i. If the bitwidth information is not available for an address space
+    (the list has no entry for it, or the entry is 0), the pass will not fold
+    any operations involving that address space.
+  }];
+  let dependentDialects = ["LLVM::LLVMDialect"];
+  let options = [
+    ListOption<"addrSpaceBWs", "address-space-bitwidths", "unsigned",
+               "Pointer bitwidth of each address space, indexed by address "
+               "space number (0 means unknown).">
+  ];
+}
+
 #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff0955c5d0e..b22280718c454 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
   DIExpressionRewriter.cpp
   DIScopeForLLVMFuncOp.cpp
   InlinerInterfaceImpl.cpp
+  IntToPtrPtrToIntFolding.cpp
   LegalizeForExport.cpp
   OptimizeForNVVM.cpp
   RequestCWrappers.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
new file mode 100644
index 0000000000000..c87a9cb2afe41
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
@@ -0,0 +1,133 @@
+//===- IntToPtrPtrToIntFolding.cpp - IntToPtr/PtrToInt folding ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that folds inttoptr/ptrtoint operation sequences.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "fold-llvm-inttoptr-ptrtoint"
+
+namespace mlir {
+namespace LLVM {
+
+#define GEN_PASS_DEF_FOLDINTTOPTRPTRTOINTPASS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return the bitwidth of a pointer or integer type, or of the element type
+/// when `type` is a vector of pointers or integers. For pointers, the bitwidth
+/// is looked up in `addrSpaceBWs`, indexed by address space; a missing entry
+/// or an entry of 0 means the bitwidth is unknown. Return failure if the
+/// bitwidth is unknown or the (element) type is neither a pointer nor an
+/// integer.
+static FailureOr<unsigned> getIntOrPtrBW(Type type,
+                                         ArrayRef<unsigned> addrSpaceBWs) {
+  // `llvm.inttoptr`/`llvm.ptrtoint` also operate on vectors of integers or
+  // pointers. The conversion is element-wise, so reason about the element
+  // type instead of asserting on the vector type below.
+  if (auto vectorType = dyn_cast<VectorType>(type))
+    type = vectorType.getElementType();
+
+  if (auto ptrType = dyn_cast<LLVM::LLVMPointerType>(type)) {
+    unsigned addrSpace = ptrType.getAddressSpace();
+    if (addrSpace < addrSpaceBWs.size() && addrSpaceBWs[addrSpace] != 0)
+      return addrSpaceBWs[addrSpace];
+    return failure();
+  }
+
+  // Be conservative on any other type (e.g. non-builtin vector types) instead
+  // of asserting: the caller then simply does not fold.
+  if (auto integerType = dyn_cast<IntegerType>(type))
+    return integerType.getWidth();
+  return failure();
+}
+
+/// Decide whether a `ptrtoint`/`inttoptr` round trip may be folded away.
+///
+/// Two conditions must hold:
+///   1. the end-to-end result type is identical to the original type, and
+///   2. the intermediate type is at least as wide as the original type.
+/// A wider intermediate value is harmless; a narrower one would drop bits of
+/// the original data. Unknown bitwidths (see `getIntOrPtrBW`) block folding.
+static LogicalResult canFoldIntToPtrPtrToInt(Type originalType,
+                                             Type intermediateType,
+                                             Type resultType,
+                                             ArrayRef<unsigned> addrSpaceBWs) {
+  // TODO: Support address space conversions?
+  // TODO: Support int trunc/ext?
+  if (resultType != originalType)
+    return failure();
+
+  FailureOr<unsigned> srcWidth = getIntOrPtrBW(originalType, addrSpaceBWs);
+  if (failed(srcWidth))
+    return failure();
+
+  FailureOr<unsigned> midWidth = getIntOrPtrBW(intermediateType, addrSpaceBWs);
+  if (failed(midWidth))
+    return failure();
+
+  // Fold only if the original data fits in the intermediate representation.
+  return success(*srcWidth <= *midWidth);
+}
+
+/// Rewrite pattern matching `DstConvOp(SrcConvOp(x))` and replacing it with
+/// `x`. Instantiated with the two inverse casts, it implements
+/// inttoptr(ptrtoint(x)) -> x and ptrtoint(inttoptr(x)) -> x. The rewrite is
+/// applied only when `canFoldIntToPtrPtrToInt` accepts the involved types.
+template <typename SrcConvOp, typename DstConvOp>
+class FoldIntToPtrPtrToInt : public OpRewritePattern<DstConvOp> {
+public:
+  FoldIntToPtrPtrToInt(MLIRContext *context, ArrayRef<unsigned> bitwidths)
+      : OpRewritePattern<DstConvOp>(context),
+        addrSpaceBWs(bitwidths.begin(), bitwidths.end()) {}
+
+  LogicalResult matchAndRewrite(DstConvOp dstConvOp,
+                                PatternRewriter &rewriter) const override {
+    // The operand of the outer cast must be produced by the inverse cast.
+    Value input = dstConvOp.getArg();
+    auto producer = input.getDefiningOp<SrcConvOp>();
+    if (!producer)
+      return failure();
+
+    // Types and bitwidths must guarantee the round trip is lossless.
+    Value original = producer.getArg();
+    if (failed(canFoldIntToPtrPtrToInt(original.getType(), producer.getType(),
+                                       dstConvOp.getType(), addrSpaceBWs)))
+      return failure();
+
+    rewriter.replaceOp(dstConvOp, original);
+    return success();
+  }
+
+private:
+  /// Pointer bitwidth per address space (0 = unknown).
+  SmallVector<unsigned> addrSpaceBWs;
+};
+
+/// Pass that greedily applies the inttoptr/ptrtoint folding patterns, using
+/// the bitwidths given through the `address-space-bitwidths` option.
+struct FoldIntToPtrPtrToIntPass
+    : public LLVM::impl::FoldIntToPtrPtrToIntPassBase<
+          FoldIntToPtrPtrToIntPass> {
+  using Base =
+      LLVM::impl::FoldIntToPtrPtrToIntPassBase<FoldIntToPtrPtrToIntPass>;
+  using Base::FoldIntToPtrPtrToIntPassBase;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet foldPatterns(ctx);
+    LLVM::populateIntToPtrPtrToIntFoldingPatterns(foldPatterns, addrSpaceBWs);
+
+    // A failure reported by the greedy driver is turned into a pass failure.
+    LogicalResult driverResult =
+        applyPatternsGreedily(getOperation(), std::move(foldPatterns));
+    if (failed(driverResult))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::LLVM::populateIntToPtrPtrToIntFoldingPatterns(
+    RewritePatternSet &patterns, ArrayRef<unsigned> addrSpaceBWs) {
+  // inttoptr(ptrtoint(x)) -> x
+  using FoldPtrRoundTrip =
+      FoldIntToPtrPtrToInt<LLVM::PtrToIntOp, LLVM::IntToPtrOp>;
+  // ptrtoint(inttoptr(x)) -> x
+  using FoldIntRoundTrip =
+      FoldIntToPtrPtrToInt<LLVM::IntToPtrOp, LLVM::PtrToIntOp>;
+  patterns.add<FoldPtrRoundTrip, FoldIntRoundTrip>(patterns.getContext(),
+                                                   addrSpaceBWs);
+}
diff --git a/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
new file mode 100644
index 0000000000000..30193eff93be2
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64}))" %s | FileCheck %s --check-prefixes=CHECK-64BIT,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=32}))" %s | FileCheck %s --check-prefixes=CHECK-32BIT,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64,32}))" %s | FileCheck %s --check-prefixes=CHECK-MULTI-ADDRSPACE,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint))" %s | FileCheck %s --check-prefixes=CHECK-DISABLED,CHECK-ALL
+
+
+// Round trip ptr (address space 0) -> i64 -> ptr. The i64 intermediate is at
+// least as wide as an address-space-0 pointer in every configuration that
+// provides its bitwidth (64 or 32), so the sequence folds to %arg0 there; both
+// casts are kept only when no bitwidths are passed to the pass.
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_64bit
+//  CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr)
+llvm.func @test_inttoptr_ptrtoint_fold_64bit(%arg0: !llvm.ptr) -> !llvm.ptr {
+  // CHECK-64BIT-NOT: llvm.ptrtoint
+  // CHECK-64BIT-NOT: llvm.inttoptr
+  //     CHECK-64BIT: llvm.return %[[ARG]]
+
+  // CHECK-32BIT-NOT: llvm.ptrtoint
+  // CHECK-32BIT-NOT: llvm.inttoptr
+  //     CHECK-32BIT: llvm.return %[[ARG]]
+
+  // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+  // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+  //     CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+  // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+  // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+  // CHECK-DISABLED: llvm.return %[[PTR]]
+
+  %0 = llvm.ptrtoint %arg0 : !llvm.ptr to i64
+  %1 = llvm.inttoptr %0 : i64 to !llvm.ptr
+  llvm.return %1 : !llvm.ptr
+}
+
+// Round trip i64 -> ptr (address space 0) -> i64. It folds only when the
+// address-space-0 pointer is at least 64 bits wide; with 32-bit pointers the
+// intermediate pointer could lose bits of %arg0, so both casts are kept.
+// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_fold_64bit
+//  CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
+llvm.func @test_ptrtoint_inttoptr_fold_64bit(%arg0: i64) -> i64 {
+  // CHECK-64BIT-NOT: llvm.inttoptr
+  // CHECK-64BIT-NOT: llvm.ptrtoint
+  //     CHECK-64BIT: llvm.return %[[ARG]]
+
+  // CHECK-32BIT: %[[PTR:.+]] = llvm.inttoptr %[[ARG]]
+  // CHECK-32BIT: %[[INT:.+]] = llvm.ptrtoint %[[PTR]]
+  // CHECK-32BIT: llvm.return %[[INT]]
+
+  // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+  // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+  //     CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+  // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[ARG]]
+  // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[PTR]]
+  // CHECK-DISABLED: llvm.return %[[INT]]
+
+  %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
+  %1 = llvm.ptrtoint %0 : !llvm.ptr to i64
+  llvm.return %1 : i64
+}
+
+// Round trip ptr<1> -> i32 -> ptr<1>. Only the configuration that provides a
+// bitwidth for address space 1 (the "64,32" list) can fold it; the other
+// configurations have no information about address space 1 and keep both
+// casts.
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_addrspace1_32bit
+//  CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
+llvm.func @test_inttoptr_ptrtoint_fold_addrspace1_32bit(%arg0: !llvm.ptr<1>) -> !llvm.ptr<1> {
+  // CHECK-64BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+  // CHECK-64BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+  // CHECK-64BIT: llvm.return %[[PTR]]
+
+  // CHECK-32BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+  // CHECK-32BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+  // CHECK-32BIT: llvm.return %[[PTR]]
+
+  // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+  // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+  //     CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+  // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+  // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+  // CHECK-DISABLED: llvm.return %[[PTR]]
+
+  %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i32
+  %1 = llvm.inttoptr %0 : i32 to !llvm.ptr<1>
+  llvm.return %1 : !llvm.ptr<1>
+}
+
+// The result type (i32) differs from the original type (i64), so the round
+// trip is never folded, whatever bitwidths are provided.
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_type_mismatch
+//  CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
+llvm.func @test_inttoptr_ptrtoint_type_mismatch(%arg0: i64) -> i32 {
+  // CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[ARG]]
+  // CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[PTR]]
+  // CHECK-ALL: llvm.return %[[INT]]
+
+  %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
+  %1 = llvm.ptrtoint %0 : !llvm.ptr to i32
+  llvm.return %1 : i32
+}
+
+// The round trip changes the address space (1 -> 0), so the result type
+// differs from the original type and nothing is folded in any configuration.
+// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_type_mismatch
+//  CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
+llvm.func @test_ptrtoint_inttoptr_type_mismatch(%arg0: !llvm.ptr<1>) -> !llvm.ptr<0> {
+  // CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+  // CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+  // CHECK-ALL: llvm.return %[[PTR]]
+  %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
+  %1 = llvm.inttoptr %0 : i64 to !llvm.ptr<0>
+  llvm.return %1 : !llvm.ptr<0>
+}



More information about the Mlir-commits mailing list