[flang-commits] [flang] [Flang] Set address space during FIR pointer-like types lowering (PR #69599)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Thu Dec 14 06:43:21 PST 2023


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/69599

>From 42aca9238224dabd547c42cc8a895a1fa1126cd2 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 19 Oct 2023 12:49:35 +0100
Subject: [PATCH 1/2] [Flang] Set address space during FIR pointer-like types
 lowering

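This patch modifies the lowering of FIR pointer-like types to the LLVM dialect so
that pointers use the alloca address space stated in the module's data layout
(`dlti.alloca_memory_space`), falling back to address space 0 when none is set.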
---
 .../include/flang/Optimizer/CodeGen/TypeConverter.h  |  3 ++-
 flang/lib/Optimizer/CodeGen/TypeConverter.cpp        |  8 +++++++-
 flang/test/Fir/alloca-addrspace-2.fir                | 12 ++++++++++++
 flang/test/Fir/alloca-addrspace.fir                  | 12 ++++++++++++
 4 files changed, 33 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Fir/alloca-addrspace-2.fir
 create mode 100644 flang/test/Fir/alloca-addrspace.fir

diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 396c1363925554..d8072b57b6c94d 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -101,7 +101,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
   }
 
   template <typename A> mlir::Type convertPointerLike(A &ty) const {
-    return mlir::LLVM::LLVMPointerType::get(ty.getContext());
+    return mlir::LLVM::LLVMPointerType::get(ty.getContext(), addressSpace);
   }
 
   // convert a front-end kind value to either a std or LLVM IR dialect type
@@ -127,6 +127,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
   KindMapping kindMapping;
   std::unique_ptr<CodeGenSpecifics> specifics;
   std::unique_ptr<TBAABuilder> tbaaBuilder;
+  unsigned addressSpace;
 };
 
 } // namespace fir
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 209c586411f410..4ab283fb060c38 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -35,7 +35,13 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
                                       getTargetTriple(module),
                                       getKindMapping(module), dl)),
       tbaaBuilder(std::make_unique<TBAABuilder>(module->getContext(), applyTBAA,
-                                                forceUnifiedTBAATree)) {
+                                                forceUnifiedTBAATree)),
+      addressSpace(0) {
+  // Get default alloca address space for the current target
+  if (mlir::Attribute addrSpace =
+          mlir::DataLayout(module).getAllocaMemorySpace())
+    addressSpace = addrSpace.cast<mlir::IntegerAttr>().getUInt();
+
   LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
 
   // Each conversion should return a value of type mlir::Type.
diff --git a/flang/test/Fir/alloca-addrspace-2.fir b/flang/test/Fir/alloca-addrspace-2.fir
new file mode 100644
index 00000000000000..8551cf8083635a
--- /dev/null
+++ b/flang/test/Fir/alloca-addrspace-2.fir
@@ -0,0 +1,12 @@
+// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s
+// RUN: tco --fir-to-llvm-ir %s | FileCheck %s
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>> } {
+  // CHECK-LABEL: llvm.func @set_addrspace
+  func.func @set_addrspace() {
+    // CHECK: llvm.alloca {{.*}} x i32
+    // CHECK-SAME: -> !llvm.ptr<i32, 5>
+    %0 = fir.alloca i32
+    return
+  }
+}
diff --git a/flang/test/Fir/alloca-addrspace.fir b/flang/test/Fir/alloca-addrspace.fir
new file mode 100644
index 00000000000000..20bf59b7a568d5
--- /dev/null
+++ b/flang/test/Fir/alloca-addrspace.fir
@@ -0,0 +1,12 @@
+// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s
+// RUN: tco --fir-to-llvm-ir %s | FileCheck %s
+
+module {
+  // CHECK-LABEL: llvm.func @default_addrspace
+  func.func @default_addrspace() {
+    // CHECK: llvm.alloca {{.*}} x i32
+    // CHECK-SAME: -> !llvm.ptr<i32>
+    %0 = fir.alloca i32
+    return
+  }
+}

>From 7f68601947da43452e9eee34a77ab4e1c6defa39 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 14 Dec 2023 14:42:49 +0000
Subject: [PATCH 2/2] Use the data layout address space for all LLVM pointer types

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 77 +++++++++++++------
 flang/lib/Optimizer/CodeGen/DescriptorModel.h | 58 ++++++++------
 flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 35 +++++----
 flang/test/Fir/alloca-addrspace-2.fir         |  2 +-
 flang/test/Fir/alloca-addrspace.fir           |  2 +-
 5 files changed, 109 insertions(+), 65 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index e07732d57880c5..02b0accc9d20fd 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -67,8 +67,45 @@ static constexpr unsigned defaultAlign = 8;
 static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
 static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
 
-static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context) {
-  return mlir::LLVM::LLVMPointerType::get(context);
+/// Return the alloca address space recorded in the data layout of `module`,
+/// or 0 when `module` is null or its data layout does not specify one. A
+/// specified alloca memory space must be an integer attribute.
+///
+/// NOTE(review): callers use this value for every pointer type they build,
+/// including heap pointers (malloc/free) and `llvm.mlir.addressof` results.
+/// The latter must match the address space of the referenced global, so a
+/// non-zero alloca address space may not be appropriate there; confirm.
+static inline unsigned getAddressSpace(mlir::ModuleOp module) {
+  if (!module)
+    return 0u;
+  if (mlir::Attribute addrSpace =
+          mlir::DataLayout(module).getAllocaMemorySpace())
+    return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
+  return 0u;
+}
+
+/// Return the alloca address space of the module enclosing the rewriter's
+/// current insertion point, or 0 if no such module can be found.
+static inline unsigned
+getAddressSpace(mlir::ConversionPatternRewriter &rewriter) {
+  mlir::Block *insertionBlock = rewriter.getInsertionBlock();
+  if (!insertionBlock)
+    return 0u;
+  mlir::Operation *parentOp = insertionBlock->getParentOp();
+  if (!parentOp)
+    return 0u;
+  // getParentOfType() does not consider parentOp itself, so handle an
+  // insertion point located directly in the module body.
+  auto module = mlir::dyn_cast<mlir::ModuleOp>(parentOp);
+  if (!module)
+    module = parentOp->getParentOfType<mlir::ModuleOp>();
+  return ::getAddressSpace(module);
+}
+
+/// Return an opaque LLVM pointer type in the given address space.
+static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context,
+                                        unsigned addressSpace) {
+  return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
 }
 
 static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
@@ -197,7 +214,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
                               mlir::ConversionPatternRewriter &rewriter,
                               int boxValue) const {
     if (box.getType().isa<mlir::LLVM::LLVMPointerType>()) {
-      auto pty = ::getLlvmPtrType(resultTy.getContext());
+      auto pty =
+          ::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter));
       auto p = rewriter.create<mlir::LLVM::GEPOp>(
           loc, pty, boxTy.llvm, box,
           llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue});
@@ -278,7 +296,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
   mlir::Value
   getBaseAddrFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
                      mlir::ConversionPatternRewriter &rewriter) const {
-    mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext());
+    mlir::Type resultTy =
+        ::getLlvmPtrType(boxTy.llvm.getContext(), ::getAddressSpace(rewriter));
     return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox);
   }
 
@@ -350,7 +369,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
                            mlir::ConversionPatternRewriter &rewriter,
                            mlir::Value base, ARGS... args) const {
     llvm::SmallVector<mlir::LLVM::GEPArg> cv = {args...};
-    auto llvmPtrTy = ::getLlvmPtrType(ty.getContext());
+    auto llvmPtrTy =
+        ::getLlvmPtrType(ty.getContext(), ::getAddressSpace(rewriter));
     return rewriter.create<mlir::LLVM::GEPOp>(loc, llvmPtrTy, ty, base, cv);
   }
 
@@ -378,7 +398,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
     mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp);
     rewriter.setInsertionPointToStart(insertBlock);
     auto size = genI32Constant(loc, rewriter, 1);
-    mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext());
+    mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext(),
+                                            ::getAddressSpace(rewriter));
     auto al = rewriter.create<mlir::LLVM::AllocaOp>(
         loc, llvmPtrTy, llvmObjectTy, size, alignment);
     rewriter.restoreInsertionPoint(thisPt);
@@ -532,7 +553,8 @@ struct AllocaOpConversion : public FIROpConversion<fir::AllocaOp> {
         size = rewriter.create<mlir::LLVM::MulOp>(
             loc, ity, size, integerCast(loc, rewriter, ity, operands[i]));
     }
-    mlir::Type llvmPtrTy = ::getLlvmPtrType(alloc.getContext());
+    mlir::Type llvmPtrTy =
+        ::getLlvmPtrType(alloc.getContext(), ::getAddressSpace(rewriter));
     // NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
     // pointers! Only propagate pinned and bindc_name to help debugging, but
     // this should have no functional purpose (and passing the operand segment
@@ -1167,9 +1189,10 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
   auto indexType = mlir::IntegerType::get(op.getContext(), 64);
   return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       rewriter.getUnknownLoc(), "malloc",
-      mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()),
-                                        indexType,
-                                        /*isVarArg=*/false));
+      mlir::LLVM::LLVMFunctionType::get(
+          getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)),
+          indexType,
+          /*isVarArg=*/false));
 }
 
 /// Helper function for generating the LLVM IR that computes the distance
@@ -1189,7 +1212,8 @@ computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
   // *)0 + 1)' trick for all types. The generated instructions are optimized
   // into constant by the first pass of InstCombine, so it should not be a
   // performance issue.
-  auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext());
+  auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext(),
+                                    ::getAddressSpace(rewriter));
   auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
   auto gep = rewriter.create<mlir::LLVM::GEPOp>(
       loc, llvmPtrTy, llvmObjectType, nullPtr,
@@ -1232,7 +1256,8 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
           loc, ity, size, integerCast(loc, rewriter, ity, opnd));
     heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
     rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
-        heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs());
+        heap, ::getLlvmPtrType(heap.getContext(), ::getAddressSpace(rewriter)),
+        size, heap->getAttrs());
     return mlir::success();
   }
 
@@ -1258,9 +1283,10 @@ getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
   auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
   return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       rewriter.getUnknownLoc(), "free",
-      mlir::LLVM::LLVMFunctionType::get(voidType,
-                                        getLlvmPtrType(op.getContext()),
-                                        /*isVarArg=*/false));
+      mlir::LLVM::LLVMFunctionType::get(
+          voidType,
+          getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)),
+          /*isVarArg=*/false));
 }
 
 static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
@@ -1386,7 +1412,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
       return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
               typeCodeVal};
     if (fir::isa_ref_type(boxEleTy)) {
-      auto ptrTy = ::getLlvmPtrType(rewriter.getContext());
+      auto ptrTy =
+          ::getLlvmPtrType(rewriter.getContext(), ::getAddressSpace(rewriter));
       return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy), typeCodeVal};
     }
     if (boxEleTy.isa<fir::RecordType>())
@@ -1447,7 +1474,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
                                 fir::RecordType recType) const {
     std::string name =
         fir::NameUniquer::getTypeDescriptorName(recType.getName());
-    mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext());
+    mlir::Type llvmPtrTy =
+        ::getLlvmPtrType(mod.getContext(), ::getAddressSpace(rewriter));
     if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name)) {
       return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
                                                       global.getSymName());
@@ -1505,7 +1533,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
             // Unlimited polymorphic type descriptor with no record type. Set
             // type descriptor address to a clean state.
             typeDesc = rewriter.create<mlir::LLVM::ZeroOp>(
-                loc, ::getLlvmPtrType(mod.getContext()));
+                loc, ::getLlvmPtrType(mod.getContext(),
+                                      ::getAddressSpace(rewriter)));
           }
         } else {
           typeDesc = getTypeDescriptor(mod, rewriter, loc,
@@ -1653,7 +1682,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
             loc, outterOffsetTy, gepArgs[0].get<mlir::Value>(), cast);
       }
     }
-    mlir::Type llvmPtrTy = ::getLlvmPtrType(resultTy.getContext());
+    mlir::Type llvmPtrTy =
+        ::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter));
     return rewriter.create<mlir::LLVM::GEPOp>(
         loc, llvmPtrTy, llvmBaseObjectType, base, gepArgs);
   }
@@ -2673,7 +2703,8 @@ struct CoordinateOpConversion
         getBaseAddrFromBox(loc, boxTyPair, boxBaseAddr, rewriter);
     // Component Type
     auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
-    mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
+    mlir::Type llvmPtrTy =
+        ::getLlvmPtrType(coor.getContext(), ::getAddressSpace(rewriter));
     mlir::Type byteTy = ::getI8Type(coor.getContext());
     mlir::LLVM::IntegerOverflowFlagsAttr nsw =
         mlir::LLVM::IntegerOverflowFlagsAttr::get(
@@ -2890,7 +2921,8 @@ struct TypeDescOpConversion : public FIROpConversion<fir::TypeDescOp> {
     auto module = typeDescOp.getOperation()->getParentOfType<mlir::ModuleOp>();
     std::string typeDescName =
         fir::NameUniquer::getTypeDescriptorName(recordType.getName());
-    auto llvmPtrTy = ::getLlvmPtrType(typeDescOp.getContext());
+    auto llvmPtrTy =
+        ::getLlvmPtrType(typeDescOp.getContext(), ::getAddressSpace(rewriter));
     if (auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) {
       rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
           typeDescOp, llvmPtrTy, global.getSymName());
@@ -3678,7 +3710,8 @@ struct BoxOffsetOpConversion : public FIROpConversion<fir::BoxOffsetOp> {
   matchAndRewrite(fir::BoxOffsetOp boxOffset, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
 
-    mlir::Type pty = ::getLlvmPtrType(boxOffset.getContext());
+    mlir::Type pty =
+        ::getLlvmPtrType(boxOffset.getContext(), ::getAddressSpace(rewriter));
     mlir::Type boxType = fir::unwrapRefType(boxOffset.getBoxRef().getType());
     mlir::Type llvmBoxTy =
         lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(boxType));
diff --git a/flang/lib/Optimizer/CodeGen/DescriptorModel.h b/flang/lib/Optimizer/CodeGen/DescriptorModel.h
index ed35caef930149..9f62f60596ac4d 100644
--- a/flang/lib/Optimizer/CodeGen/DescriptorModel.h
+++ b/flang/lib/Optimizer/CodeGen/DescriptorModel.h
@@ -31,7 +31,7 @@
 
 namespace fir {
 
-using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
+using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *, unsigned);
 
 /// Get the LLVM IR dialect model for building a particular C++ type, `T`.
 template <typename T>
@@ -39,64 +39,72 @@ TypeBuilderFunc getModel();
 
 template <>
 TypeBuilderFunc getModel<void *>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::LLVM::LLVMPointerType::get(context);
+  return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type {
+    return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
   };
 }
 template <>
 TypeBuilderFunc getModel<unsigned>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
-  };
+  return
+      [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
+        return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
+      };
 }
 template <>
 TypeBuilderFunc getModel<int>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::IntegerType::get(context, sizeof(int) * 8);
-  };
+  return
+      [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
+        return mlir::IntegerType::get(context, sizeof(int) * 8);
+      };
 }
 template <>
 TypeBuilderFunc getModel<unsigned long>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
-  };
+  return
+      [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
+        return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
+      };
 }
 template <>
 TypeBuilderFunc getModel<unsigned long long>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
-  };
+  return
+      [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
+        return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
+      };
 }
 template <>
 TypeBuilderFunc getModel<long long>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::IntegerType::get(context, sizeof(long long) * 8);
-  };
+  return
+      [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
+        return mlir::IntegerType::get(context, sizeof(long long) * 8);
+      };
 }
 template <>
 TypeBuilderFunc getModel<Fortran::ISO::CFI_rank_t>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
+  return [](mlir::MLIRContext *context,
+            unsigned /*addressSpace*/) -> mlir::Type {
     return mlir::IntegerType::get(context,
                                   sizeof(Fortran::ISO::CFI_rank_t) * 8);
   };
 }
 template <>
 TypeBuilderFunc getModel<Fortran::ISO::CFI_type_t>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
+  return [](mlir::MLIRContext *context,
+            unsigned /*addressSpace*/) -> mlir::Type {
     return mlir::IntegerType::get(context,
                                   sizeof(Fortran::ISO::CFI_type_t) * 8);
   };
 }
 template <>
 TypeBuilderFunc getModel<long>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::IntegerType::get(context, sizeof(long) * 8);
-  };
+  return
+      [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
+        return mlir::IntegerType::get(context, sizeof(long) * 8);
+      };
 }
 template <>
 TypeBuilderFunc getModel<Fortran::ISO::CFI_dim_t>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context);
+  return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type {
+    auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context, addressSpace);
     return mlir::LLVM::LLVMArrayType::get(indexTy, 3);
   };
 }
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 4ab283fb060c38..7404ee6c6244dc 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -73,7 +73,7 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
   addConversion([&](fir::LenType field) {
     // Get size of len paramter from the descriptor.
     return getModel<Fortran::runtime::typeInfo::TypeParameterValue>()(
-        &getContext());
+        &getContext(), addressSpace);
   });
   addConversion([&](fir::LogicalType boolTy) {
     return mlir::IntegerType::get(
@@ -220,25 +220,25 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
     dataDescFields.push_back(eleTy);
   else
     dataDescFields.push_back(
-        mlir::LLVM::LLVMPointerType::get(eleTy.getContext()));
+        mlir::LLVM::LLVMPointerType::get(eleTy.getContext(), addressSpace));
   // elem_len
   dataDescFields.push_back(
-      getDescFieldTypeModel<kElemLenPosInBox>()(&getContext()));
+      getDescFieldTypeModel<kElemLenPosInBox>()(&getContext(), addressSpace));
   // version
   dataDescFields.push_back(
-      getDescFieldTypeModel<kVersionPosInBox>()(&getContext()));
+      getDescFieldTypeModel<kVersionPosInBox>()(&getContext(), addressSpace));
   // rank
   dataDescFields.push_back(
-      getDescFieldTypeModel<kRankPosInBox>()(&getContext()));
+      getDescFieldTypeModel<kRankPosInBox>()(&getContext(), addressSpace));
   // type
   dataDescFields.push_back(
-      getDescFieldTypeModel<kTypePosInBox>()(&getContext()));
+      getDescFieldTypeModel<kTypePosInBox>()(&getContext(), addressSpace));
   // attribute
   dataDescFields.push_back(
-      getDescFieldTypeModel<kAttributePosInBox>()(&getContext()));
+      getDescFieldTypeModel<kAttributePosInBox>()(&getContext(), addressSpace));
   // f18Addendum
-  dataDescFields.push_back(
-      getDescFieldTypeModel<kF18AddendumPosInBox>()(&getContext()));
+  dataDescFields.push_back(getDescFieldTypeModel<kF18AddendumPosInBox>()(
+      &getContext(), addressSpace));
   // [dims]
   if (rank == unknownRank()) {
     if (auto seqTy = ele.dyn_cast<SequenceType>())
@@ -247,15 +247,17 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
       rank = 0;
   }
   if (rank > 0) {
-    auto rowTy = getDescFieldTypeModel<kDimsPosInBox>()(&getContext());
+    auto rowTy =
+        getDescFieldTypeModel<kDimsPosInBox>()(&getContext(), addressSpace);
     dataDescFields.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, rank));
   }
   // opt-type-ptr: i8* (see fir.tdesc)
   if (requiresExtendedDesc(ele) || fir::isUnlimitedPolymorphicType(box)) {
     dataDescFields.push_back(
-        getExtendedDescFieldTypeModel<kOptTypePtrPosInBox>()(&getContext()));
-    auto rowTy =
-        getExtendedDescFieldTypeModel<kOptRowTypePosInBox>()(&getContext());
+        getExtendedDescFieldTypeModel<kOptTypePtrPosInBox>()(&getContext(),
+                                                             addressSpace));
+    auto rowTy = getExtendedDescFieldTypeModel<kOptRowTypePosInBox>()(
+        &getContext(), addressSpace);
     dataDescFields.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, 1));
     if (auto recTy = fir::unwrapSequenceType(ele).dyn_cast<fir::RecordType>())
       if (recTy.getNumLenParams() > 0) {
@@ -278,13 +280,14 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
 mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) const {
   // TODO: send the box type and the converted LLVM structure layout
   // to tbaaBuilder for proper creation of TBAATypeDescriptorOp.
-  return mlir::LLVM::LLVMPointerType::get(box.getContext());
+  return mlir::LLVM::LLVMPointerType::get(box.getContext(), addressSpace);
 }
 
 // fir.boxproc<any>  -->  llvm<"{ any*, i8* }">
 mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) const {
   auto funcTy = convertType(boxproc.getEleTy());
-  auto voidPtrTy = mlir::LLVM::LLVMPointerType::get(boxproc.getContext());
+  auto voidPtrTy =
+      mlir::LLVM::LLVMPointerType::get(boxproc.getContext(), addressSpace);
   llvm::SmallVector<mlir::Type, 2> tuple = {funcTy, voidPtrTy};
   return mlir::LLVM::LLVMStructType::getLiteral(boxproc.getContext(), tuple,
                                                 /*isPacked=*/false);
@@ -335,7 +338,7 @@ mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const {
 // the f18 object v. class distinction (F2003).
 mlir::Type
 LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) const {
-  return mlir::LLVM::LLVMPointerType::get(ctx);
+  return mlir::LLVM::LLVMPointerType::get(ctx, addressSpace);
 }
 
 // Relay TBAA tag attachment to TBAABuilder.
diff --git a/flang/test/Fir/alloca-addrspace-2.fir b/flang/test/Fir/alloca-addrspace-2.fir
index 8551cf8083635a..6ba23630dba138 100644
--- a/flang/test/Fir/alloca-addrspace-2.fir
+++ b/flang/test/Fir/alloca-addrspace-2.fir
@@ -5,7 +5,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_mem
   // CHECK-LABEL: llvm.func @set_addrspace
   func.func @set_addrspace() {
     // CHECK: llvm.alloca {{.*}} x i32
-    // CHECK-SAME: -> !llvm.ptr<i32, 5>
+    // CHECK-SAME: -> !llvm.ptr<5>
     %0 = fir.alloca i32
     return
   }
diff --git a/flang/test/Fir/alloca-addrspace.fir b/flang/test/Fir/alloca-addrspace.fir
index 20bf59b7a568d5..a5f3a18355ad3a 100644
--- a/flang/test/Fir/alloca-addrspace.fir
+++ b/flang/test/Fir/alloca-addrspace.fir
@@ -5,7 +5,8 @@ module {
   // CHECK-LABEL: llvm.func @default_addrspace
   func.func @default_addrspace() {
     // CHECK: llvm.alloca {{.*}} x i32
-    // CHECK-SAME: -> !llvm.ptr<i32>
+    // Anchor at end of line so that a non-default address space is rejected.
+    // CHECK-SAME: -> !llvm.ptr{{$}}
     %0 = fir.alloca i32
     return
   }



More information about the flang-commits mailing list