[Mlir-commits] [mlir] [MLIR][LLVM] Have LLVM::AddressOfOp implement ConstantLike (PR #90481)

Johannes de Fine Licht llvmlistbot at llvm.org
Mon Apr 29 08:55:31 PDT 2024


https://github.com/definelicht updated https://github.com/llvm/llvm-project/pull/90481

>From 3c1ea9ed5f14ff572c71255b0baca2698a0bea92 Mon Sep 17 00:00:00 2001
From: Johannes de Fine Licht <johannes.definelicht at nextsilicon.com>
Date: Mon, 29 Apr 2024 14:57:47 +0000
Subject: [PATCH] [MLIR][LLVM] Have LLVM::AddressOfOp implement ConstantLike.

For all intents and purposes, llvm.mlir.addressof acts like a constant
and should be treated as such by passes. In particular, the operation
should be propagated to its uses rather than having its result passed
around as a value whenever possible.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  8 +--
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 13 ++++-
 .../test/Dialect/LLVMIR/constant-folding.mlir | 50 +++++++++++++++++++
 3 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index eedae4b9bb7c8e..6655ce6f123e14 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -63,12 +63,12 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
   let arguments = !con(commonArgs, iofArg);
 
   let builders = [
-    OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs, 
+    OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
                    "IntegerOverflowFlags":$overflowFlags), [{
       build($_builder, $_state, type, lhs, rhs);
       $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
     }]>,
-    OpBuilder<(ins "Value":$lhs, "Value":$rhs, 
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
                    "IntegerOverflowFlags":$overflowFlags), [{
       build($_builder, $_state, lhs, rhs);
       $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
@@ -1052,7 +1052,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
 ////////////////////////////////////////////////////////////////////////////////
 
 def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
-    [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+    [Pure, ConstantLike, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let arguments = (ins FlatSymbolRefAttr:$global_name);
   let results = (outs LLVM_AnyPointer:$res);
 
@@ -1114,6 +1114,8 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
   }];
 
   let assemblyFormat = "$global_name attr-dict `:` qualified(type($res))";
+
+  let hasFolder = 1;
 }
 
 def LLVM_GlobalOp : LLVM_Op<"mlir.global",
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 4e06b9c127e76a..4335b66f2ef37f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1785,7 +1785,7 @@ LogicalResult ReturnOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// Verifier for LLVM::AddressOfOp.
+// LLVM::AddressOfOp.
 //===----------------------------------------------------------------------===//
 
 static Operation *parentLLVMModule(Operation *op) {
@@ -1826,6 +1826,11 @@ AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+// AddressOfOp constant-folds to its global symbol name (a FlatSymbolRefAttr).
+OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
+  return getGlobalNameAttr();
+}
+
 //===----------------------------------------------------------------------===//
 // Verifier for LLVM::ComdatOp.
 //===----------------------------------------------------------------------===//
@@ -3258,6 +3263,12 @@ LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
 
 Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                             Type type, Location loc) {
+  // A FlatSymbolRefAttr requested as a pointer is what AddressOfOp::fold
+  // produces; materialize it back as an llvm.mlir.addressof operation.
+  if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value))
+    if (isa<LLVM::LLVMPointerType>(type))
+      return builder.create<LLVM::AddressOfOp>(loc, type, symbol);
+  // Otherwise try materializing it as a regular llvm.mlir.constant op.
   return LLVM::ConstantOp::materialize(builder, value, type, loc);
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/constant-folding.mlir b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
index f800f2690467da..454126321eb970 100644
--- a/mlir/test/Dialect/LLVMIR/constant-folding.mlir
+++ b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
@@ -51,3 +51,53 @@ llvm.func @or_basic() -> i32 {
   // CHECK: llvm.return %[[RES]] : i32
   llvm.return %2 : i32
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @addressof
+llvm.func @addressof() {
+  // CHECK-NEXT: %[[ADDRESSOF:.+]] = llvm.mlir.addressof @foo
+  %0 = llvm.mlir.addressof @foo : !llvm.ptr
+  %1 = llvm.mlir.addressof @foo : !llvm.ptr
+  // CHECK-NEXT: llvm.call @bar(%[[ADDRESSOF]], %[[ADDRESSOF]])
+  llvm.call @bar(%0, %1) : (!llvm.ptr, !llvm.ptr) -> ()
+  // CHECK-NEXT: llvm.return
+  llvm.return
+}
+
+llvm.mlir.global constant @foo() : i32
+
+llvm.func @bar(!llvm.ptr, !llvm.ptr)
+
+// -----
+
+// CHECK-LABEL: llvm.func @addressof_select
+llvm.func @addressof_select(%arg: i1) -> !llvm.ptr {
+  // CHECK-NEXT: %[[ADDRESSOF:.+]] = llvm.mlir.addressof @foo
+  %0 = llvm.mlir.addressof @foo : !llvm.ptr
+  %1 = llvm.mlir.addressof @foo : !llvm.ptr
+  %2 = arith.select %arg, %0, %1 : !llvm.ptr
+  // CHECK-NEXT: llvm.return %[[ADDRESSOF]]
+  llvm.return %2 : !llvm.ptr
+}
+
+llvm.mlir.global constant @foo() : i32
+
+llvm.func @bar(!llvm.ptr, !llvm.ptr)
+
+// -----
+
+// CHECK-LABEL: llvm.func @addressof_blocks
+llvm.func @addressof_blocks(%arg: i1) -> !llvm.ptr {
+  // CHECK-NEXT: %[[ADDRESSOF:.+]] = llvm.mlir.addressof @foo
+  llvm.cond_br %arg, ^bb1, ^bb2
+^bb1:
+  %0 = llvm.mlir.addressof @foo : !llvm.ptr
+  llvm.return %0 : !llvm.ptr
+^bb2:
+  %1 = llvm.mlir.addressof @foo : !llvm.ptr
+  // CHECK: return %[[ADDRESSOF]]
+  llvm.return %1 : !llvm.ptr
+}
+
+llvm.mlir.global constant @foo() : i32



More information about the Mlir-commits mailing list