[Mlir-commits] [mlir] [mlir][LLVMIR] Add folder pass for `llvm.inttoptr` and `llvm.ptrtoint` (PR #143066)
Diego Caballero
llvmlistbot at llvm.org
Thu Jun 5 22:02:32 PDT 2025
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/143066
This PR is a follow-up to https://github.com/llvm/llvm-project/pull/141891. It introduces a pass that can fold `inttoptr(ptrtoint(x)) -> x` and `ptrtoint(inttoptr(x)) -> x`. The pass takes in a list of address space bitwidths and makes sure that the folding is applied only when it's safe.
From 23924b917579882d56bd99ca429a8b1670e3e432 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Fri, 6 Jun 2025 04:52:17 +0000
Subject: [PATCH] [mlir][LLVMIR] Add folder pass for `llvm.inttoptr` and
`llvm.ptrtoint`
This PR is a follow-up to https://github.com/llvm/llvm-project/pull/141891.
It introduces a pass that can fold `inttoptr(ptrtoint(x)) -> x` and
`ptrtoint(inttoptr(x)) -> x`. The pass takes in a list of address space
bitwidths and makes sure that the folding is applied only when it's safe.
---
.../Transforms/IntToPtrPtrToIntFolding.h | 43 ++++++
.../mlir/Dialect/LLVMIR/Transforms/Passes.td | 19 +++
.../Dialect/LLVMIR/Transforms/CMakeLists.txt | 1 +
.../Transforms/IntToPtrPtrToIntFolding.cpp | 133 ++++++++++++++++++
.../LLVMIR/inttoptr-ptrtoint-folding.mlir | 100 +++++++++++++
5 files changed, 296 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
create mode 100644 mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
create mode 100644 mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
new file mode 100644
index 0000000000000..62a9c84241a39
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
@@ -0,0 +1,43 @@
+//===- IntToPtrPtrToIntFolding.h - IntToPtr/PtrToInt folding ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a pass that folds inttoptr/ptrtoint operation sequences.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace LLVM {
+
+#define GEN_PASS_DECL_FOLDINTTOPTRPTRTOINTPASS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+/// Populate patterns that fold inttoptr/ptrtoint op sequences such as:
+///
+/// * `inttoptr(ptrtoint(x))` -> `x`
+/// * `ptrtoint(inttoptr(x))` -> `x`
+///
+/// `addressSpaceBWs[i]` is the pointer bitwidth of address space `i`. If an
+/// address space has no entry, or its entry is 0, its bitwidth is treated as
+/// unknown and the folding for that address space is not performed.
+///
+/// TODO: Support DLTI.
+void populateIntToPtrPtrToIntFoldingPatterns(
+ RewritePatternSet &patterns, ArrayRef<unsigned> addressSpaceBWs);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index 961909d5c8d27..be45213e6b95e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -73,4 +73,23 @@ def DIScopeForLLVMFuncOpPass : Pass<"ensure-debug-info-scope-on-llvm-func", "::m
];
}
+def FoldIntToPtrPtrToIntPass
+    : Pass<"fold-llvm-inttoptr-ptrtoint", "::mlir::LLVM::LLVMFuncOp"> {
+  let summary = "Fold inttoptr/ptrtoint operation sequences";
+  let description = [{
+    This pass folds sequences of inttoptr and ptrtoint operations that cancel
+    each other out. Specifically:
+    * inttoptr(ptrtoint(x)) -> x
+    * ptrtoint(inttoptr(x)) -> x
+
+    The pass takes a sequence of address space bitwidths to make sure folding
+    is safe. The i-th entry is the pointer bitwidth of address space i. If the
+    bitwidth information is not available for an address space (no entry, or
+    an entry of 0), the pass will not fold any operations involving that
+    address space.
+  }];
+  let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
+  let options = [
+    ListOption<"addrSpaceBWs", "address-space-bitwidths", "unsigned",
+               "List of pointer bitwidths indexed by address space. An entry "
+               "of 0 marks the bitwidth of that address space as unknown.">
+  ];
+}
+
#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff0955c5d0e..b22280718c454 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
DIExpressionRewriter.cpp
DIScopeForLLVMFuncOp.cpp
InlinerInterfaceImpl.cpp
+ IntToPtrPtrToIntFolding.cpp
LegalizeForExport.cpp
OptimizeForNVVM.cpp
RequestCWrappers.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
new file mode 100644
index 0000000000000..c87a9cb2afe41
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
@@ -0,0 +1,133 @@
+//===- IntToPtrPtrToIntFolding.cpp - IntToPtr/PtrToInt folding ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that folds inttoptr/ptrtoint operation sequences.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "fold-llvm-inttoptr-ptrtoint"
+
+namespace mlir {
+namespace LLVM {
+
+#define GEN_PASS_DEF_FOLDINTTOPTRPTRTOINTPASS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return the bitwidth of a pointer or integer type. For a pointer type, the
+/// bitwidth is looked up in `addrSpaceBWs` using the pointer's address space as
+/// the index; a missing entry or an entry of 0 means the bitwidth is unknown.
+/// Return failure if the bitwidth is unknown or if `type` is neither a scalar
+/// pointer nor a scalar integer (e.g., the vector forms accepted by
+/// `llvm.inttoptr`/`llvm.ptrtoint`), so that callers skip the fold instead of
+/// tripping a `cast` assertion.
+static FailureOr<unsigned> getIntOrPtrBW(Type type,
+                                         ArrayRef<unsigned> addrSpaceBWs) {
+  if (auto ptrType = dyn_cast<LLVM::LLVMPointerType>(type)) {
+    unsigned addrSpace = ptrType.getAddressSpace();
+    if (addrSpace < addrSpaceBWs.size() && addrSpaceBWs[addrSpace] != 0)
+      return addrSpaceBWs[addrSpace];
+    return failure();
+  }
+
+  // TODO: Support vectors of pointers/integers.
+  if (auto integerType = dyn_cast<IntegerType>(type))
+    return integerType.getWidth();
+  return failure();
+}
+
+/// Decide whether an inttoptr/ptrtoint round trip can be replaced by its
+/// input. The fold is legal only if the round trip ends on the very type it
+/// started from and the intermediate type is wide enough to carry every bit of
+/// the original value.
+static LogicalResult canFoldIntToPtrPtrToInt(Type originalType,
+                                             Type intermediateType,
+                                             Type resultType,
+                                             ArrayRef<unsigned> addrSpaceBWs) {
+  // The end-to-end conversion must be an identity on the type.
+  // TODO: Support address space conversions?
+  // TODO: Support int trunc/ext?
+  if (originalType != resultType)
+    return failure();
+
+  FailureOr<unsigned> srcWidth = getIntOrPtrBW(originalType, addrSpaceBWs);
+  if (failed(srcWidth))
+    return failure();
+
+  FailureOr<unsigned> midWidth =
+      getIntOrPtrBW(intermediateType, addrSpaceBWs);
+  if (failed(midWidth))
+    return failure();
+
+  // A wider intermediate only drops bits that the original value never had,
+  // which is harmless. A narrower one would truncate the original value.
+  return success(*srcWidth <= *midWidth);
+}
+
+/// Rewrite pattern that removes a `DstConvOp` fed by a `SrcConvOp` when the
+/// pair forms a lossless round trip, i.e. it folds
+/// inttoptr(ptrtoint(x)) -> x or ptrtoint(inttoptr(x)) -> x.
+template <typename SrcConvOp, typename DstConvOp>
+class FoldIntToPtrPtrToInt : public OpRewritePattern<DstConvOp> {
+public:
+  FoldIntToPtrPtrToInt(MLIRContext *context, ArrayRef<unsigned> addrSpaceBWs)
+      : OpRewritePattern<DstConvOp>(context), addrSpaceBWs(addrSpaceBWs) {}
+
+  LogicalResult matchAndRewrite(DstConvOp dstConvOp,
+                                PatternRewriter &rewriter) const override {
+    // Only match when the operand comes straight from the inverse conversion.
+    auto producer = dstConvOp.getArg().template getDefiningOp<SrcConvOp>();
+    if (!producer)
+      return failure();
+
+    // Gather the three types along the chain and let the helper decide
+    // whether the round trip preserves the original value.
+    Value roundTripInput = producer.getArg();
+    Type inputType = roundTripInput.getType();
+    Type midType = producer.getType();
+    Type outputType = dstConvOp.getType();
+    if (failed(canFoldIntToPtrPtrToInt(inputType, midType, outputType,
+                                       addrSpaceBWs)))
+      return failure();
+
+    rewriter.replaceOp(dstConvOp, roundTripInput);
+    return success();
+  }
+
+private:
+  /// Pointer bitwidths, owned by the pattern, as passed to the constructor.
+  SmallVector<unsigned> addrSpaceBWs;
+};
+
+/// Pass that greedily applies the inttoptr/ptrtoint folding patterns to the
+/// operation it runs on, using the `addrSpaceBWs` pass option.
+struct FoldIntToPtrPtrToIntPass
+    : public LLVM::impl::FoldIntToPtrPtrToIntPassBase<
+          FoldIntToPtrPtrToIntPass> {
+  using Base =
+      LLVM::impl::FoldIntToPtrPtrToIntPassBase<FoldIntToPtrPtrToIntPass>;
+  using Base::FoldIntToPtrPtrToIntPassBase;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet foldPatterns(ctx);
+    LLVM::populateIntToPtrPtrToIntFoldingPatterns(foldPatterns, addrSpaceBWs);
+
+    LogicalResult status =
+        applyPatternsGreedily(getOperation(), std::move(foldPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::LLVM::populateIntToPtrPtrToIntFoldingPatterns(
+    RewritePatternSet &patterns, ArrayRef<unsigned> addrSpaceBWs) {
+  MLIRContext *context = patterns.getContext();
+
+  // inttoptr(ptrtoint(x)) -> x
+  patterns.add<FoldIntToPtrPtrToInt<LLVM::PtrToIntOp, LLVM::IntToPtrOp>>(
+      context, addrSpaceBWs);
+
+  // ptrtoint(inttoptr(x)) -> x
+  patterns.add<FoldIntToPtrPtrToInt<LLVM::IntToPtrOp, LLVM::PtrToIntOp>>(
+      context, addrSpaceBWs);
+}
diff --git a/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
new file mode 100644
index 0000000000000..30193eff93be2
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64}))" %s | FileCheck %s --check-prefixes=CHECK-64BIT,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=32}))" %s | FileCheck %s --check-prefixes=CHECK-32BIT,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64,32}))" %s | FileCheck %s --check-prefixes=CHECK-MULTI-ADDRSPACE,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint))" %s | FileCheck %s --check-prefixes=CHECK-DISABLED,CHECK-ALL
+
+
+// Round trip `ptr -> i64 -> ptr` in address space 0. Folded in every run that
+// provides a bitwidth for address space 0 (the i64 intermediate is at least as
+// wide as a 32-bit or a 64-bit pointer); kept when no bitwidths are given.
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_64bit
+// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr)
+llvm.func @test_inttoptr_ptrtoint_fold_64bit(%arg0: !llvm.ptr) -> !llvm.ptr {
+ // CHECK-64BIT-NOT: llvm.ptrtoint
+ // CHECK-64BIT-NOT: llvm.inttoptr
+ // CHECK-64BIT: llvm.return %[[ARG]]
+
+ // CHECK-32BIT-NOT: llvm.ptrtoint
+ // CHECK-32BIT-NOT: llvm.inttoptr
+ // CHECK-32BIT: llvm.return %[[ARG]]
+
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+ // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+ // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-DISABLED: llvm.return %[[PTR]]
+
+ %0 = llvm.ptrtoint %arg0 : !llvm.ptr to i64
+ %1 = llvm.inttoptr %0 : i64 to !llvm.ptr
+ llvm.return %1 : !llvm.ptr
+}
+
+// Round trip `i64 -> ptr -> i64` through address space 0. Folded only when the
+// pointer is known to be 64-bit wide (`64` and `64,32` runs); kept with 32-bit
+// pointers, which would truncate the integer, and when no bitwidths are given.
+// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_fold_64bit
+// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
+llvm.func @test_ptrtoint_inttoptr_fold_64bit(%arg0: i64) -> i64 {
+ // CHECK-64BIT-NOT: llvm.inttoptr
+ // CHECK-64BIT-NOT: llvm.ptrtoint
+ // CHECK-64BIT: llvm.return %[[ARG]]
+
+ // CHECK-32BIT: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
+ // CHECK-32BIT: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
+ // CHECK-32BIT: llvm.return %[[PTR]]
+
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+ // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+ // CHECK-DISABLED: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
+ // CHECK-DISABLED: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
+ // CHECK-DISABLED: llvm.return %[[PTR]]
+
+ %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
+ %1 = llvm.ptrtoint %0 : !llvm.ptr to i64
+ llvm.return %1 : i64
+}
+
+// Round trip `ptr<1> -> i32 -> ptr<1>`. Only the `64,32` run provides a
+// bitwidth for address space 1 (32 bits), so only that run folds; all other
+// runs have no information about address space 1 and keep the conversions.
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_addrspace1_32bit
+// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
+llvm.func @test_inttoptr_ptrtoint_fold_addrspace1_32bit(%arg0: !llvm.ptr<1>) -> !llvm.ptr<1> {
+ // CHECK-64BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-64BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-64BIT: llvm.return %[[PTR]]
+
+ // CHECK-32BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-32BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-32BIT: llvm.return %[[PTR]]
+
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+ // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+ // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-DISABLED: llvm.return %[[PTR]]
+
+ %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i32
+ %1 = llvm.inttoptr %0 : i32 to !llvm.ptr<1>
+ llvm.return %1 : !llvm.ptr<1>
+}
+
+// `ptrtoint(inttoptr(x))` where the result integer type (i32) differs from the
+// original one (i64): never folded, regardless of the provided bitwidths.
+// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_type_mismatch
+// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
+llvm.func @test_ptrtoint_inttoptr_type_mismatch(%arg0: i64) -> i32 {
+  // CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[ARG]]
+  // CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[PTR]]
+  // CHECK-ALL: llvm.return %[[INT]]
+
+  %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
+  %1 = llvm.ptrtoint %0 : !llvm.ptr to i32
+  llvm.return %1 : i32
+}
+
+// `inttoptr(ptrtoint(x))` where the result pointer lives in a different
+// address space (0) than the original one (1): never folded, regardless of the
+// provided bitwidths.
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_type_mismatch
+// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
+llvm.func @test_inttoptr_ptrtoint_type_mismatch(%arg0: !llvm.ptr<1>) -> !llvm.ptr<0> {
+  // CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+  // CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+  // CHECK-ALL: llvm.return %[[PTR]]
+
+  %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
+  %1 = llvm.inttoptr %0 : i64 to !llvm.ptr<0>
+  llvm.return %1 : !llvm.ptr<0>
+}
More information about the Mlir-commits
mailing list