[Mlir-commits] [mlir] 499abb2 - Add generic type attribute mapping infrastructure, use it in GpuToX

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Feb 9 10:00:51 PST 2023


Author: Krzysztof Drewniak
Date: 2023-02-09T18:00:46Z
New Revision: 499abb243cb75262f121659b87f5a2a6a7c8a82f

URL: https://github.com/llvm/llvm-project/commit/499abb243cb75262f121659b87f5a2a6a7c8a82f
DIFF: https://github.com/llvm/llvm-project/commit/499abb243cb75262f121659b87f5a2a6a7c8a82f.diff

LOG: Add generic type attribute mapping infrastructure, use it in GpuToX

Remapping memory spaces is often needed in type conversions, most
commonly when lowering to LLVM or converting to/from SPIR-V (the
subject of a future commit), and such remappings may become more
common as dialects take advantage of the more generic memory space
infrastructure.

Currently, memory space remappings are handled by a special-purpose
conversion pass that rewrites the address space attributes and runs
before the main conversion. This commit replaces that approach by
adding a notion of type attribute conversions to TypeConverter, which
is then used to convert memory space attributes.
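As a rough illustration of the mechanism documented on `addTypeAttributeConversion` in the diff below (callbacks that return a converted attribute, a "not applicable" marker, or an abort, tried in reverse registration order), here is a minimal standalone C++ sketch. It does not use MLIR: `Attr`, `GpuAddressSpace`, `ConversionResult`, `AttrConverter` and `makeRocdlLikeConverter` are hypothetical stand-ins, the `Type` argument that the real callbacks receive is omitted, and an abort is assumed to end the search with no result. The 1/3/5 numbering is the one the ROCDL lowering in this commit uses for global, workgroup and private memory.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical attribute model (not MLIR): std::monostate plays the null
// attribute, int64_t an integer memory space, GpuAddressSpace a symbolic one.
enum class GpuAddressSpace { Global, Workgroup, Private };
using Attr = std::variant<std::monostate, std::int64_t, GpuAddressSpace>;

// Three-state callback result: a (possibly null) attribute, "not applicable",
// or "abort the whole lookup".
class ConversionResult {
public:
  static ConversionResult result(Attr attr) {
    return ConversionResult(State::Result, std::move(attr));
  }
  static ConversionResult na() { return ConversionResult(State::Na, {}); }
  static ConversionResult abort() { return ConversionResult(State::Abort, {}); }

  bool hasResult() const { return state == State::Result; }
  bool isNa() const { return state == State::Na; }
  bool isAbort() const { return state == State::Abort; }
  const Attr &getResult() const { return value; }

private:
  enum class State { Na, Result, Abort };
  ConversionResult(State s, Attr v) : state(s), value(std::move(v)) {}
  State state;
  Attr value;
};

class AttrConverter {
public:
  using Callback = std::function<ConversionResult(const Attr &)>;

  void addConversion(Callback cb) { callbacks.push_back(std::move(cb)); }

  // Callbacks are tried newest-first. The first result wins; an abort, or
  // running out of callbacks, yields std::nullopt.
  std::optional<Attr> convert(const Attr &attr) const {
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
      ConversionResult res = (*it)(attr);
      if (res.hasResult())
        return res.getResult();
      if (res.isAbort())
        return std::nullopt;
    }
    return std::nullopt;
  }

private:
  std::vector<Callback> callbacks;
};

// Registers an identity rule for integer spaces, then a rule for the GPU enum.
AttrConverter makeRocdlLikeConverter() {
  AttrConverter converter;
  converter.addConversion([](const Attr &a) {
    if (std::holds_alternative<std::int64_t>(a))
      return ConversionResult::result(a);
    return ConversionResult::na();
  });
  converter.addConversion([](const Attr &a) {
    const auto *space = std::get_if<GpuAddressSpace>(&a);
    if (!space)
      return ConversionResult::na(); // Let older callbacks try.
    switch (*space) {
    case GpuAddressSpace::Global:
      return ConversionResult::result(std::int64_t{1});
    case GpuAddressSpace::Workgroup:
      return ConversionResult::result(std::int64_t{3});
    case GpuAddressSpace::Private:
      return ConversionResult::result(std::int64_t{5});
    }
    return ConversionResult::abort(); // Out-of-range enum value.
  });
  return converter;
}
```

Registering the integer identity rule first and the GPU rule afterwards mirrors the setup in the diff: `LLVMTypeConverter`'s constructor adds the integer rule, and the GPU passes add their enum mapping on top of it. Because lookup runs newest-first, a later registration can shadow an earlier one for the same attribute.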

We then use this infrastructure throughout the *ToLLVM conversions.
This has the advantage of loosening the requirement on the inputs to
those passes from "all address spaces must be integers" to "all
memory spaces must be convertible to integer spaces", a weaker
requirement that reduces the coupling between portions of MLIR.
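The relaxed requirement comes down to a short decision procedure, implemented by `LLVMTypeConverter::getMemRefAddressSpace` in the diff below: no memory space means address space 0, a conversion to the null attribute also means 0, an integer result is used as-is, and anything else (including no applicable conversion) is a failure. The sketch below reproduces that order of checks in plain C++ under stated assumptions: `Attr`, `AttrConvertFn`, `memRefAddressSpace` and `integerIdentityOnly` are hypothetical names, and `std::optional<unsigned>` stands in for MLIR's `FailureOr<unsigned>`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <variant>

// Hypothetical attribute model: monostate = no memory space (the default),
// int64_t = integer memory space, std::string = some other symbolic attribute.
using Attr = std::variant<std::monostate, std::int64_t, std::string>;

// Stand-in for TypeConverter::convertTypeAttribute: std::nullopt means no
// registered rule produced a result.
using AttrConvertFn = std::function<std::optional<Attr>(const Attr &)>;

// Same order of checks as getMemRefAddressSpace; std::nullopt means failure.
std::optional<unsigned> memRefAddressSpace(const Attr &memorySpace,
                                           const AttrConvertFn &convert) {
  if (std::holds_alternative<std::monostate>(memorySpace))
    return 0u; // Default memory space -> 0.
  std::optional<Attr> converted = convert(memorySpace);
  if (!converted)
    return std::nullopt; // No applicable conversion.
  if (std::holds_alternative<std::monostate>(*converted))
    return 0u; // Conversion to the default space -> 0.
  if (const auto *space = std::get_if<std::int64_t>(&*converted))
    return static_cast<unsigned>(*space);
  return std::nullopt; // Converted, but still not an integer.
}

// Analogue of the integer identity rule LLVMTypeConverter registers by default.
std::optional<Attr> integerIdentityOnly(const Attr &attr) {
  if (std::holds_alternative<std::int64_t>(attr))
    return attr;
  return std::nullopt;
}
```

With only the integer identity rule registered, an integer space passes through unchanged and an unrecognized symbolic space yields a failure, which is the case where `getMemRefDescriptorFields` in the diff emits its "Consider adding memory space conversions" error.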

On top of that, this change removes most of the calls to
getMemorySpaceAsInt(), bringing us closer to removing it.

(A rework of the SPIR-V conversions to use this new system will be in
a follow-up commit.)

As a note, one long-term motivation for this change is that I would
eventually like to add an allocaMemorySpace key to MLIR data layouts
and then call getMemRefAddressSpace(allocaMemorySpace) in the
relevant *ToLLVM passes in order to ensure that all alloca()s, whether
incoming or produced during the LLVM lowering, have the correct
address space for a given target.

I expect that the type attribute conversion system may be useful in
other contexts.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D142159

Added: 
    mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir
    mlir/test/Conversion/MemRefToLLVM/invalid.mlir

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
    mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 85f1c698beffc..b13b88d6773a8 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -147,6 +147,11 @@ class LLVMTypeConverter : public TypeConverter {
   unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
                                            const DataLayout &layout);
 
+  /// Return the LLVM address space corresponding to the memory space of the
+  /// memref type `type` or failure if the memory space cannot be converted to
+  /// an integer.
+  FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type);
+
   /// Check if a memref type can be converted to a bare pointer.
   static bool canConvertToBarePtr(BaseMemRefType type);
 

diff  --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
index f08f9fb59dee5..13c368676cf4a 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -25,7 +25,8 @@ def AMDGPU_Dialect : Dialect {
 
 
   let dependentDialects = [
-    "arith::ArithDialect"
+    "arith::ArithDialect",
+    "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
 }

diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 594d8be0838b7..e3eee6de84879 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -61,23 +61,6 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
 }
 
 namespace gpu {
-/// A function that maps a MemorySpace enum to a target-specific integer value.
-using MemorySpaceMapping =
-    std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
-
-/// Populates type conversion rules for lowering memory space attributes to
-/// numeric values.
-void populateMemorySpaceAttributeTypeConversions(
-    TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
-
-/// Populates patterns to lower memory space attributes to numeric values.
-void populateMemorySpaceLoweringPatterns(TypeConverter &typeConverter,
-                                         RewritePatternSet &patterns);
-
-/// Populates legality rules for lowering memory space attriutes to numeric
-/// values.
-void populateLowerMemorySpaceOpLegality(ConversionTarget &target);
-
 /// Returns the default annotation name for GPU binary blobs.
 std::string getDefaultGpuBinaryAnnotation();
 

diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index fae2f0f37fc95..a144fa4127ddf 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -37,23 +37,4 @@ def GpuMapParallelLoopsPass
   let dependentDialects = ["mlir::gpu::GPUDialect"];
 }
 
-def GPULowerMemorySpaceAttributesPass
-    : Pass<"gpu-lower-memory-space-attributes"> {
-  let summary = "Assign numeric values to memref memory space symbolic placeholders";
-  let description = [{
-    Updates all memref types that have a memory space attribute
-    that is a `gpu::AddressSpaceAttr`. These attributes are
-    changed to `IntegerAttr`'s using a mapping that is given in the
-    options.
-  }];
-  let options = [
-    Option<"privateAddrSpace", "private", "unsigned", "5",
-      "private address space numeric value">,
-    Option<"workgroupAddrSpace", "workgroup", "unsigned", "3",
-      "workgroup address space numeric value">,
-    Option<"globalAddrSpace", "global", "unsigned", "1",
-      "global address space numeric value">
-  ];
-}
-
 #endif // MLIR_DIALECT_GPU_PASSES

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9e10e3f96d54f..c592f2db999a2 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -21,6 +21,7 @@
 namespace mlir {
 
 // Forward declarations.
+class Attribute;
 class Block;
 class ConversionPatternRewriter;
 class MLIRContext;
@@ -87,6 +88,34 @@ class TypeConverter {
     SmallVector<Type, 4> argTypes;
   };
 
+  /// The general result of a type attribute conversion callback, allowing
+  /// for early termination. The default constructor creates the na case.
+  class AttributeConversionResult {
+  public:
+    constexpr AttributeConversionResult() : impl() {}
+    AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
+
+    static AttributeConversionResult result(Attribute attr);
+    static AttributeConversionResult na();
+    static AttributeConversionResult abort();
+
+    bool hasResult() const;
+    bool isNa() const;
+    bool isAbort() const;
+
+    Attribute getResult() const;
+
+  private:
+    AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
+
+    llvm::PointerIntPair<Attribute, 2> impl;
+    // Note that na is 0 so that we can use PointerIntPair's default
+    // constructor.
+    static constexpr unsigned naTag = 0;
+    static constexpr unsigned resultTag = 1;
+    static constexpr unsigned abortTag = 2;
+  };
+
   /// Register a conversion function. A conversion function must be convertible
   /// to any of the following forms(where `T` is a class derived from `Type`:
   ///   * std::optional<Type>(T)
@@ -156,6 +185,34 @@ class TypeConverter {
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }
 
+  /// Register a conversion function for attributes within types. Type
+  /// converters may call this function in order to allow hooking into the
+  /// translation of attributes that exist within types. For example, a type
+  /// converter for the `memref` type could use these conversions to convert
+  /// memory spaces or layouts in an extensible way.
+  ///
+  /// The conversion functions take a non-null Type or subclass of Type and a
+  /// non-null Attribute (or subclass of Attribute), and return an
+  /// `AttributeConversionResult`. This result can either contain an `Attribute`,
+  /// which may be `nullptr`, representing the conversion's success,
+  /// `AttributeConversionResult::na()` (the default empty value), indicating
+  /// that the conversion function did not apply and that further conversion
+  /// functions should be checked, or `AttributeConversionResult::abort()`
+  /// indicating that the conversion process should be aborted.
+  ///
+  /// Registered conversion functions are called in the reverse of the order in
+  /// which they were registered.
+  template <
+      typename FnT,
+      typename T =
+          typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
+      typename A =
+          typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
+  void addTypeAttributeConversion(FnT &&callback) {
+    registerTypeAttributeConversion(
+        wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
+  }
+
   /// Convert the given type. This function should return failure if no valid
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
@@ -226,6 +283,12 @@ class TypeConverter {
                                  resultType, inputs);
   }
 
+  /// Convert an attribute `attr` present within the type `type` using
+  /// the registered conversion functions. If no applicable conversion has been
+  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
+  /// is a valid return value for this function.
+  std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
+
 private:
   /// The signature of the callback used to convert a type. If the new set of
   /// types is empty, the type is removed and any usages of the existing value
@@ -237,6 +300,10 @@ class TypeConverter {
   using MaterializationCallbackFn = std::function<std::optional<Value>(
       OpBuilder &, Type, ValueRange, Location)>;
 
+  /// The signature of the callback used to convert a type attribute.
+  using TypeAttributeConversionCallbackFn =
+      std::function<AttributeConversionResult(Type, Attribute)>;
+
   /// Attempt to materialize a conversion using one of the provided
   /// materialization functions.
   Value materializeConversion(
@@ -311,6 +378,32 @@ class TypeConverter {
     };
   }
 
+  /// Generate a wrapper for the given memory space conversion callback. The
+  /// callback may take any subclass of `Attribute` and the wrapper will check
+  /// for the target attribute to be of the expected class before calling the
+  /// callback.
+  template <typename T, typename A, typename FnT>
+  TypeAttributeConversionCallbackFn
+  wrapTypeAttributeConversion(FnT &&callback) {
+    return [callback = std::forward<FnT>(callback)](
+               Type type, Attribute attr) -> AttributeConversionResult {
+      if (T derivedType = type.dyn_cast<T>()) {
+        if (A derivedAttr = attr.dyn_cast_or_null<A>())
+          return callback(derivedType, derivedAttr);
+      }
+      return AttributeConversionResult::na();
+    };
+  }
+
+  /// Register a memory space conversion, clearing caches.
+  void
+  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
+    typeAttributeConversions.emplace_back(std::move(callback));
+    // Clear type conversions in case a memory space is lingering inside.
+    cachedDirectConversions.clear();
+    cachedMultiConversions.clear();
+  }
+
   /// The set of registered conversion functions.
   SmallVector<ConversionCallbackFn, 4> conversions;
 
@@ -319,6 +412,9 @@ class TypeConverter {
   SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
   SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
 
+  /// The list of registered type attribute conversion functions.
+  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
+
   /// A set of cached conversions to avoid recomputing in the common case.
   /// Direct 1-1 conversions are the most common, so this cache stores the
   /// successful 1-1 conversions as well as all failed conversions.

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 48c0cbf379880..636e0d54dd272 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -8,6 +8,7 @@
 
 #include "GPUOpsLowering.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/STLExtras.h"
@@ -474,3 +475,18 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
   rewriter.replaceOp(op, result);
   return success();
 }
+
+static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
+  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
+}
+
+void mlir::populateGpuMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
+  typeConverter.addTypeAttributeConversion(
+      [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
+        gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
+        unsigned addressSpace = mapping(memorySpace);
+        return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
+                                      addressSpace);
+      });
+}

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 55efb2230ad93..1895c65d99964 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -112,6 +112,14 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
   }
 };
 
+/// A function that maps a MemorySpace enum to a target-specific integer value.
+using MemorySpaceMapping =
+    std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
+
+/// Populates memory space attribute conversion rules for lowering
+/// gpu.address_space to integer values.
+void populateGpuMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 08ea2cae4cd7c..1cde673b5b000 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -241,38 +241,26 @@ struct LowerGpuOpsToNVVMOpsPass
         return signalPassFailure();
     }
 
-    // MemRef conversion for GPU to NVVM lowering.
-    {
-      RewritePatternSet patterns(m.getContext());
-      TypeConverter typeConverter;
-      typeConverter.addConversion([](Type t) { return t; });
-      // NVVM uses alloca in the default address space to represent private
-      // memory allocations, so drop private annotations. NVVM uses address
-      // space 3 for shared memory. NVVM uses the default address space to
-      // represent global memory.
-      gpu::populateMemorySpaceAttributeTypeConversions(
-          typeConverter, [](gpu::AddressSpace space) -> unsigned {
-            switch (space) {
-            case gpu::AddressSpace::Global:
-              return static_cast<unsigned>(
-                  NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-            case gpu::AddressSpace::Workgroup:
-              return static_cast<unsigned>(
-                  NVVM::NVVMMemorySpace::kSharedMemorySpace);
-            case gpu::AddressSpace::Private:
-              return 0;
-            }
-            llvm_unreachable("unknown address space enum value");
-            return 0;
-          });
-      gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
-      ConversionTarget target(getContext());
-      gpu::populateLowerMemorySpaceOpLegality(target);
-      if (failed(applyFullConversion(m, target, std::move(patterns))))
-        return signalPassFailure();
-    }
-
     LLVMTypeConverter converter(m.getContext(), options);
+    // NVVM uses alloca in the default address space to represent private
+    // memory allocations, so drop private annotations. NVVM uses address
+    // space 3 for shared memory. NVVM uses the default address space to
+    // represent global memory.
+    populateGpuMemorySpaceAttributeConversions(
+        converter, [](gpu::AddressSpace space) -> unsigned {
+          switch (space) {
+          case gpu::AddressSpace::Global:
+            return static_cast<unsigned>(
+                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+          case gpu::AddressSpace::Workgroup:
+            return static_cast<unsigned>(
+                NVVM::NVVMMemorySpace::kSharedMemorySpace);
+          case gpu::AddressSpace::Private:
+            return 0;
+          }
+          llvm_unreachable("unknown address space enum value");
+          return 0;
+        });
     // Lowering for MMAMatrixType.
     converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
       return convertMMAToLLVMType(type);

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 233e2fb3681d1..3c154bc720765 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -132,33 +132,21 @@ struct LowerGpuOpsToROCDLOpsPass
       (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
     }
 
-    // Apply memory space lowering. The target uses 3 for workgroup memory and 5
-    // for private memory.
-    {
-      RewritePatternSet patterns(ctx);
-      TypeConverter typeConverter;
-      typeConverter.addConversion([](Type t) { return t; });
-      gpu::populateMemorySpaceAttributeTypeConversions(
-          typeConverter, [](gpu::AddressSpace space) {
-            switch (space) {
-            case gpu::AddressSpace::Global:
-              return 1;
-            case gpu::AddressSpace::Workgroup:
-              return 3;
-            case gpu::AddressSpace::Private:
-              return 5;
-            }
-            llvm_unreachable("unknown address space enum value");
-            return 0;
-          });
-      ConversionTarget target(getContext());
-      gpu::populateLowerMemorySpaceOpLegality(target);
-      gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
-      if (failed(applyFullConversion(m, target, std::move(patterns))))
-        return signalPassFailure();
-    }
-
     LLVMTypeConverter converter(ctx, options);
+    populateGpuMemorySpaceAttributeConversions(
+        converter, [](gpu::AddressSpace space) {
+          switch (space) {
+          case gpu::AddressSpace::Global:
+            return 1;
+          case gpu::AddressSpace::Workgroup:
+            return 3;
+          case gpu::AddressSpace::Private:
+            return 5;
+          }
+          llvm_unreachable("unknown address space enum value");
+          return 0;
+        });
+
     RewritePatternSet llvmPatterns(ctx);
 
     mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 09e68b3cc5e5e..69a0172ebf0b1 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -112,8 +112,10 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
   auto structElementType = typeConverter->convertType(elementType);
-  return getTypeConverter()->getPointerType(structElementType,
-                                            type.getMemorySpaceAsInt());
+  auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
+  if (failed(addressSpace))
+    return {};
+  return getTypeConverter()->getPointerType(structElementType, *addressSpace);
 }
 
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index f15c9f2a40e13..1b640bfd83a17 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -158,6 +158,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
+
+  // Integer memory spaces map to themselves.
+  addTypeAttributeConversion(
+      [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
 }
 
 /// Returns the MLIR context.
@@ -318,8 +322,17 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
   if (!elementType)
     return {};
 
-  LLVM::LLVMPointerType ptrTy =
-      getPointerType(elementType, type.getMemorySpaceAsInt());
+  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
+  if (failed(addressSpace)) {
+    emitError(UnknownLoc::get(type.getContext()),
+              "conversion of memref memory space ")
+        << type.getMemorySpace()
+        << " to integer address space "
+           "failed. Consider adding memory space conversions.";
+    return {};
+  }
+  auto ptrTy = getPointerType(elementType, *addressSpace);
+
   auto indexTy = getIndexType();
 
   SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
@@ -337,7 +350,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
 unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
                                                     const DataLayout &layout) {
   // Compute the descriptor size given that of its components indicated above.
-  unsigned space = type.getMemorySpaceAsInt();
+  unsigned space = *getMemRefAddressSpace(type);
   return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
          (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
 }
@@ -369,7 +382,7 @@ unsigned
 LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
                                                    const DataLayout &layout) {
   // Compute the descriptor size given that of its components indicated above.
-  unsigned space = type.getMemorySpaceAsInt();
+  unsigned space = *getMemRefAddressSpace(type);
   return layout.getTypeSize(getIndexType()) +
          llvm::divideCeil(getPointerBitwidth(space), 8);
 }
@@ -381,6 +394,21 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
                                           getUnrankedMemRefDescriptorFields());
 }
 
+FailureOr<unsigned>
+LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
+  if (!type.getMemorySpace()) // Default memory space -> 0.
+    return 0;
+  Optional<Attribute> converted =
+      convertTypeAttribute(type, type.getMemorySpace());
+  if (!converted)
+    return failure();
+  if (!(*converted)) // Conversion to default is 0.
+    return 0;
+  if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
+    return explicitSpace.getInt();
+  return failure();
+}
+
 // Check if a memref type can be converted to a bare pointer.
 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
   if (type.isa<UnrankedMemRefType>())
@@ -412,7 +440,10 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
   Type elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
-  return getPointerType(elementType, type.getMemorySpaceAsInt());
+  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
+  if (failed(addressSpace))
+    return {};
+  return getPointerType(elementType, *addressSpace);
 }
 
 /// Convert an n-D vector type to an LLVM vector type:

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 4afe2ffebb342..e29c5d2565770 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -59,11 +59,12 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                  MemRefType memRefType, Type elementPtrType,
                                  LLVMTypeConverter &typeConverter) {
   auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
-  if (allocatedPtrTy.getAddressSpace() != memRefType.getMemorySpaceAsInt())
+  unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType);
+  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
     allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
         loc,
         typeConverter.getPointerType(allocatedPtrTy.getElementType(),
-                                     memRefType.getMemorySpaceAsInt()),
+                                     memrefAddrSpace),
         allocatedPtr);
 
   if (!typeConverter.useOpaquePointers())

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index d6097f057db40..c3ff7d86852ae 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -96,8 +96,10 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
     auto allocaOp = cast<memref::AllocaOp>(op);
     auto elementType =
         typeConverter->convertType(allocaOp.getType().getElementType());
-    auto elementPtrType = getTypeConverter()->getPointerType(
-        elementType, allocaOp.getType().getMemorySpaceAsInt());
+    unsigned addrSpace =
+        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
+    auto elementPtrType =
+        getTypeConverter()->getPointerType(elementType, addrSpace);
 
     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
         loc, elementPtrType, elementType, sizeBytes,
@@ -400,10 +402,11 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Type operandType = dimOp.getSource().getType();
     if (operandType.isa<UnrankedMemRefType>()) {
-      rewriter.replaceOp(
-          dimOp, {extractSizeOfUnrankedMemRef(
-                     operandType, dimOp, adaptor.getOperands(), rewriter)});
-
+      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
+          operandType, dimOp, adaptor.getOperands(), rewriter);
+      if (failed(extractedSize))
+        return failure();
+      rewriter.replaceOp(dimOp, {*extractedSize});
       return success();
     }
     if (operandType.isa<MemRefType>()) {
@@ -416,15 +419,23 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
   }
 
 private:
-  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
-                                    OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
+  FailureOr<Value>
+  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
+                              OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter) const {
     Location loc = dimOp.getLoc();
 
     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
     auto scalarMemRefType =
         MemRefType::get({}, unrankedMemRefType.getElementType());
-    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
+    FailureOr<unsigned> maybeAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
+    if (failed(maybeAddressSpace)) {
+      dimOp.emitOpError("memref memory space must be convertible to an integer "
+                        "address space");
+      return failure();
+    }
+    unsigned addressSpace = *maybeAddressSpace;
 
     // Extract pointer to the underlying ranked descriptor and bitcast it to a
     // memref<element_type> descriptor pointer to minimize the number of GEP
@@ -455,8 +466,9 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
     Value sizePtr = rewriter.create<LLVM::GEPOp>(
         loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
         idxPlusOne);
-    return rewriter.create<LLVM::LoadOp>(
-        loc, getTypeConverter()->getIndexType(), sizePtr);
+    return rewriter
+        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
+        .getResult();
   }
 
   std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
@@ -670,10 +682,14 @@ struct GlobalMemrefOpLowering
     }
 
     uint64_t alignment = global.getAlignment().value_or(0);
-
+    FailureOr<unsigned> addressSpace =
+        getTypeConverter()->getMemRefAddressSpace(type);
+    if (failed(addressSpace))
+      return global.emitOpError(
+          "memory space cannot be converted to an integer address space");
     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
-        initialValue, alignment, type.getMemorySpaceAsInt());
+        initialValue, alignment, *addressSpace);
     if (!global.isExternal() && global.isUninitialized()) {
       Block *blk = new Block();
       newGlobal.getInitializerRegion().push_back(blk);
@@ -701,7 +717,10 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
                                           Operation *op) const override {
     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
     MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
-    unsigned memSpace = type.getMemorySpaceAsInt();
+
+    // This is called after a type conversion, which would have failed if this
+    // call fails.
+    unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type);
 
     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace);
@@ -1097,8 +1116,9 @@ static void extractPointersAndOffset(Location loc,
     return;
   }
 
-  unsigned memorySpace =
-      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
+  // These will all cause assert()s on unconvertible types.
+  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
+      operandType.cast<UnrankedMemRefType>());
   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
   Type llvmElementType = typeConverter.convertType(elementType);
   LLVM::LLVMPointerType elementPtrType =
@@ -1298,7 +1318,8 @@ struct MemRefReshapeOpLowering
     // Extract address space and element type.
     auto targetType =
         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
-    unsigned addressSpace = targetType.getMemorySpaceAsInt();
+    unsigned addressSpace =
+        *getTypeConverter()->getMemRefAddressSpace(targetType);
     Type elementType = targetType.getElementType();
 
     // Create the unranked memref descriptor that holds the ranked one. The
@@ -1564,14 +1585,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
     // Field 1: Copy the allocated pointer, used for malloc/free.
     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
+    unsigned sourceMemorySpace =
+        *getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
     Value bitcastPtr;
     if (getTypeConverter()->useOpaquePointers())
       bitcastPtr = allocatedPtr;
     else
       bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-          loc,
-          LLVM::LLVMPointerType::get(targetElementTy,
-                                     srcMemRefType.getMemorySpaceAsInt()),
+          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
           allocatedPtr);
 
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -1587,9 +1608,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
       bitcastPtr = alignedPtr;
     } else {
       bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-          loc,
-          LLVM::LLVMPointerType::get(targetElementTy,
-                                     srcMemRefType.getMemorySpaceAsInt()),
+          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
           alignedPtr);
     }
 

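The MemRefToLLVM hunks above replace getMemorySpaceAsInt() with getTypeConverter()->getMemRefAddressSpace(), which returns FailureOr<unsigned>. Call sites use it in one of two ways. Some check the result and report an error (GlobalMemrefOpLowering). Others dereference it directly, because an earlier type conversion would already have failed (GetGlobalMemrefOpLowering, extractPointersAndOffset). Below is a dependency-free sketch of both calling styles. It uses std::optional as a stand-in for FailureOr and a std::variant as a stand-in for the memory space attribute. MemorySpace, getAddressSpace, tryLowerGlobal and addressSpaceAfterConversion are hypothetical names, not MLIR API. The mapping rule "absent space becomes 0, integer space maps to itself" is also an assumption of the sketch.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-in for a memref memory space attribute: absent,
// an integer, or some other (here: string) attribute such as "foo".
using MemorySpace = std::variant<std::monostate, unsigned, std::string>;

// Stand-in for getMemRefAddressSpace(): std::nullopt plays the role of a
// failed FailureOr<unsigned>. Assumed rule: absent -> 0, integer -> itself.
std::optional<unsigned> getAddressSpace(const MemorySpace &space) {
  if (std::holds_alternative<std::monostate>(space))
    return 0u;
  if (const unsigned *n = std::get_if<unsigned>(&space))
    return *n;
  return std::nullopt; // no conversion known for this attribute
}

// Style 1: check the result and report the failure to the caller,
// analogous to the emitOpError() path in GlobalMemrefOpLowering.
bool tryLowerGlobal(const MemorySpace &space, unsigned &addressSpace,
                    std::string &error) {
  std::optional<unsigned> converted = getAddressSpace(space);
  if (!converted) {
    error = "memory space cannot be converted to an integer address space";
    return false;
  }
  addressSpace = *converted;
  return true;
}

// Style 2: dereference directly, relying on an earlier type conversion
// having already rejected unconvertible memory spaces.
unsigned addressSpaceAfterConversion(const MemorySpace &space) {
  std::optional<unsigned> converted = getAddressSpace(space);
  assert(converted && "type conversion should have failed earlier");
  return *converted;
}
```

Style 2 is only safe when something upstream guarantees convertibility. That is the reasoning given in the comments added at the dereferencing call sites.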
diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 296683b1a50dd..922fd7c12d461 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -572,16 +572,24 @@ struct NVGPUAsyncCopyLowering
     Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                         adaptor.getDstIndices(), rewriter);
     auto i8Ty = IntegerType::get(op.getContext(), 8);
-    auto dstPointerType =
-        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
+    FailureOr<unsigned> dstAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
+    if (failed(dstAddressSpace))
+      return rewriter.notifyMatchFailure(
+          loc, "destination memref address space not convertible to integer");
+    auto dstPointerType = LLVM::LLVMPointerType::get(i8Ty, *dstAddressSpace);
     dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
 
     auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
+    FailureOr<unsigned> srcAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
+    if (failed(srcAddressSpace))
+      return rewriter.notifyMatchFailure(
+          loc, "source memref address space not convertible to integer");
 
     Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                         adaptor.getSrcIndices(), rewriter);
-    auto srcPointerType =
-        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
+    auto srcPointerType = LLVM::LLVMPointerType::get(i8Ty, *srcAddressSpace);
     scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index cdd8cd77aa9c0..509a03c1096a5 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -742,6 +742,8 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   if (failed(warpMatrixInfo))
     return failure();
 
+  Attribute memorySpace =
+      op.getSource().getType().cast<MemRefType>().getMemorySpace();
   bool isLdMatrixCompatible =
       isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 35def299a73f1..4a193f932084f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -92,21 +93,25 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
 }
 
 // Check if the last stride is non-unit or the memory space is not zero.
-static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
+static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
+                                           LLVMTypeConverter &converter) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
-  if (failed(successStrides) || strides.back() != 1 ||
-      memRefType.getMemorySpaceAsInt() != 0)
+  FailureOr<unsigned> addressSpace =
+      converter.getMemRefAddressSpace(memRefType);
+  if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) ||
+      *addressSpace != 0)
     return failure();
   return success();
 }
 
 // Add an index vector component to a base pointer.
 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                            LLVMTypeConverter &typeConverter,
                             MemRefType memRefType, Value llvmMemref, Value base,
                             Value index, uint64_t vLen) {
-  assert(succeeded(isMemRefTypeSupported(memRefType)) &&
+  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
          "unsupported memref type");
   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
   auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
@@ -116,8 +121,10 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
 // Casts a strided element pointer to a vector pointer.  The vector pointer
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
-                         Value ptr, MemRefType memRefType, Type vt) {
-  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
+                         Value ptr, MemRefType memRefType, Type vt,
+                         LLVMTypeConverter &converter) {
+  unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
+  auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
 }
 
@@ -245,7 +252,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
                      .template cast<VectorType>();
     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                                adaptor.getIndices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
+                            *this->getTypeConverter());
 
     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
     return success();
@@ -264,7 +272,7 @@ class VectorGatherOpConversion
     MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
     assert(memRefType && "The base should be bufferized");
 
-    if (failed(isMemRefTypeSupported(memRefType)))
+    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
     auto loc = gather->getLoc();
@@ -283,8 +291,8 @@ class VectorGatherOpConversion
     if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
       auto vType = gather.getVectorType();
       // Resolve address.
-      Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
-                                  adaptor.getIndexVec(),
+      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
+                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                   /*vLen=*/vType.getDimSize(0));
       // Replace with the gather intrinsic.
       rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -293,11 +301,14 @@ class VectorGatherOpConversion
       return success();
     }
 
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter](
-                        Type llvm1DVectorTy, ValueRange vectorOperands) {
+    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
+                     &typeConverter](Type llvm1DVectorTy,
+                                     ValueRange vectorOperands) {
       // Resolve address.
       Value ptrs = getIndexedPtrs(
-          rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
+          rewriter, loc, typeConverter, memRefType, base, ptr,
+          /*index=*/vectorOperands[0],
           LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
       // Create the gather intrinsic.
       return rewriter.create<LLVM::masked_gather>(
@@ -323,7 +334,7 @@ class VectorScatterOpConversion
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
 
-    if (failed(isMemRefTypeSupported(memRefType)))
+    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
     // Resolve alignment.
@@ -335,9 +346,9 @@ class VectorScatterOpConversion
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Value ptrs =
-        getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
-                       adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
+    Value ptrs = getIndexedPtrs(
+        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
+        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(

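isMemRefTypeSupported now takes the LLVMTypeConverter. It accepts a memref only if the innermost stride is 1 and the memory space converts to integer address space 0. A memory space that cannot be converted is rejected here rather than asserting. Below is a minimal sketch of that predicate. It assumes a hypothetical MemRefInfo struct that summarizes what getStridesAndOffset and getMemRefAddressSpace would report; it does not use MLIR's MemRefType.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical summary of the facts isMemRefTypeSupported inspects.
struct MemRefInfo {
  bool hasStridedLayout;                // getStridesAndOffset() succeeded
  std::vector<int64_t> strides;         // one stride per dimension
  std::optional<unsigned> addressSpace; // std::nullopt: not convertible
};

// Supported only if the innermost stride is 1 and the memory space
// converts to integer address space 0.
bool isSupported(const MemRefInfo &m) {
  // The empty() guard is an addition of this sketch (rank-0 shapes).
  if (!m.hasStridedLayout || m.strides.empty() || m.strides.back() != 1)
    return false;
  return m.addressSpace.has_value() && *m.addressSpace == 0;
}
```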
diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index ee73672947368..57bdfd57582c8 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -48,7 +49,16 @@ void AMDGPUDialect::initialize() {
 template <typename T>
 static LogicalResult verifyRawBufferOp(T &op) {
   MemRefType bufferType = op.getMemref().getType().template cast<MemRefType>();
-  if (bufferType.getMemorySpaceAsInt() != 0)
+  Attribute memorySpace = bufferType.getMemorySpace();
+  bool isGlobal = false;
+  if (!memorySpace)
+    isGlobal = true;
+  else if (auto intMemorySpace = memorySpace.dyn_cast<IntegerAttr>())
+    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+  else if (auto gpuMemorySpace = memorySpace.dyn_cast<gpu::AddressSpaceAttr>())
+    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+
+  if (!isGlobal)
     return op.emitOpError(
         "Buffer ops must operate on a memref in global memory");
   if (!bufferType.hasRank())

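verifyRawBufferOp no longer requires the integer memory space 0. It accepts:

- a missing memory space,
- the integer spaces 0 or 1,
- #gpu.address_space<global>.

Anything else is rejected. Below is a standalone sketch of that classification. It uses a hypothetical std::variant in place of the MLIR Attribute, and the std::string alternative stands for any other attribute kind.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>

// Hypothetical stand-ins for the attribute kinds verifyRawBufferOp checks.
enum class GpuAddressSpace { Global, Workgroup, Private };
struct NoMemorySpace {};
// std::string stands for any other attribute kind.
using MemorySpace =
    std::variant<NoMemorySpace, int64_t, GpuAddressSpace, std::string>;

bool isGlobalMemory(const MemorySpace &space) {
  if (std::holds_alternative<NoMemorySpace>(space))
    return true; // no memory space attribute at all
  if (const int64_t *n = std::get_if<int64_t>(&space))
    return *n == 0 || *n == 1; // integer spaces 0 and 1
  if (const GpuAddressSpace *g = std::get_if<GpuAddressSpace>(&space))
    return *g == GpuAddressSpace::Global; // #gpu.address_space<global>
  return false; // anything else is rejected by the verifier
}
```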
diff  --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 5dde478898d40..eaa850a14ac7e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,8 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  # Needed for GPU address space enum definition
+  MLIRGPUOps
   MLIRIR
   MLIRSideEffectInterfaces
   )

diff  --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index a38695878c103..94f3ab505f23a 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -52,7 +52,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
-  Transforms/LowerMemorySpaceAttributes.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU

diff  --git a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
deleted file mode 100644
index 1d292c5fa45ea..0000000000000
--- a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
+++ /dev/null
@@ -1,179 +0,0 @@
-//===- LowerMemorySpaceAttributes.cpp  ------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// Implementation of a pass that rewrites the IR so that uses of
-/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced
-/// with caller-specified numeric values.
-///
-//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
-
-namespace mlir {
-#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS
-#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::gpu;
-
-//===----------------------------------------------------------------------===//
-// Conversion Target
-//===----------------------------------------------------------------------===//
-
-/// Returns true if the given `type` is considered as legal during memory space
-/// attribute lowering.
-static bool isLegalType(Type type) {
-  if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
-    return !memRefType.getMemorySpace()
-                .isa_and_nonnull<gpu::AddressSpaceAttr>();
-  }
-  return true;
-}
-
-/// Returns true if the given `attr` is considered legal during memory space
-/// attribute lowering.
-static bool isLegalAttr(Attribute attr) {
-  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
-    return isLegalType(typeAttr.getValue());
-  return true;
-}
-
-/// Returns true if the given `op` is legal during memory space attribute
-/// lowering.
-static bool isLegalOp(Operation *op) {
-  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-    return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
-           llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
-           llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
-                        isLegalType);
-  }
-
-  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
-    return attr.getValue();
-  });
-
-  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
-         llvm::all_of(op->getResultTypes(), isLegalType) &&
-         llvm::all_of(attrs, isLegalAttr);
-}
-
-void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) {
-  target.markUnknownOpDynamicallyLegal(isLegalOp);
-}
-
-//===----------------------------------------------------------------------===//
-// Type Converter
-//===----------------------------------------------------------------------===//
-
-IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
-  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
-}
-
-void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
-    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
-  typeConverter.addConversion([mapping](Type type) {
-    return type.replace([mapping](Attribute attr) -> std::optional<Attribute> {
-      auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
-      if (!memorySpaceAttr)
-        return std::nullopt;
-      auto newValue = wrapNumericMemorySpace(
-          attr.getContext(), mapping(memorySpaceAttr.getValue()));
-      return newValue;
-    });
-  });
-}
-
-namespace {
-
-/// Converts any op that has operands/results/attributes with numeric MemRef
-/// memory spaces.
-struct LowerMemRefAddressSpacePattern final : public ConversionPattern {
-  LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter)
-      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    SmallVector<NamedAttribute> newAttrs;
-    newAttrs.reserve(op->getAttrs().size());
-    for (auto attr : op->getAttrs()) {
-      if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
-        auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
-        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
-      } else {
-        newAttrs.push_back(attr);
-      }
-    }
-
-    SmallVector<Type> newResults;
-    (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
-
-    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                         newResults, newAttrs, op->getSuccessors());
-
-    for (Region &region : op->getRegions()) {
-      Region *newRegion = state.addRegion();
-      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
-      (void)getTypeConverter()->convertSignatureArgs(
-          newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
-    }
-
-    Operation *newOp = rewriter.create(state);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-};
-} // namespace
-
-void mlir::gpu::populateMemorySpaceLoweringPatterns(
-    TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<LowerMemRefAddressSpacePattern>(patterns.getContext(),
-                                               typeConverter);
-}
-
-namespace {
-class LowerMemorySpaceAttributesPass
-    : public mlir::impl::GPULowerMemorySpaceAttributesPassBase<
-          LowerMemorySpaceAttributesPass> {
-public:
-  using Base::Base;
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    Operation *op = getOperation();
-
-    ConversionTarget target(getContext());
-    populateLowerMemorySpaceOpLegality(target);
-
-    TypeConverter typeConverter;
-    typeConverter.addConversion([](Type t) { return t; });
-    populateMemorySpaceAttributeTypeConversions(
-        typeConverter, [this](AddressSpace space) -> unsigned {
-          switch (space) {
-          case AddressSpace::Global:
-            return globalAddrSpace;
-          case AddressSpace::Workgroup:
-            return workgroupAddrSpace;
-          case AddressSpace::Private:
-            return privateAddrSpace;
-          }
-          llvm_unreachable("unknown address space enum value");
-          return 0;
-        });
-    RewritePatternSet patterns(context);
-    populateMemorySpaceLoweringPatterns(typeConverter, patterns);
-    if (failed(applyFullConversion(op, target, std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-} // namespace

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index abdf0a9533b56..d0641ed0eccb6 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3053,6 +3053,54 @@ auto TypeConverter::convertBlockSignature(Block *block)
   return conversion;
 }
 
+//===----------------------------------------------------------------------===//
+// Type attribute conversion
+//===----------------------------------------------------------------------===//
+TypeConverter::AttributeConversionResult
+TypeConverter::AttributeConversionResult::result(Attribute attr) {
+  return AttributeConversionResult(attr, resultTag);
+}
+
+TypeConverter::AttributeConversionResult
+TypeConverter::AttributeConversionResult::na() {
+  return AttributeConversionResult(nullptr, naTag);
+}
+
+TypeConverter::AttributeConversionResult
+TypeConverter::AttributeConversionResult::abort() {
+  return AttributeConversionResult(nullptr, abortTag);
+}
+
+bool TypeConverter::AttributeConversionResult::hasResult() const {
+  return impl.getInt() == resultTag;
+}
+
+bool TypeConverter::AttributeConversionResult::isNa() const {
+  return impl.getInt() == naTag;
+}
+
+bool TypeConverter::AttributeConversionResult::isAbort() const {
+  return impl.getInt() == abortTag;
+}
+
+Attribute TypeConverter::AttributeConversionResult::getResult() const {
+  assert(hasResult() && "Cannot get result from N/A or abort");
+  return impl.getPointer();
+}
+
+Optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
+                                                        Attribute attr) {
+  for (TypeAttributeConversionCallbackFn &fn :
+       llvm::reverse(typeAttributeConversions)) {
+    AttributeConversionResult res = fn(type, attr);
+    if (res.hasResult())
+      return res.getResult();
+    if (res.isAbort())
+      return std::nullopt;
+  }
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // FunctionOpInterfaceSignatureConversion
 //===----------------------------------------------------------------------===//

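Each type attribute conversion callback returns one of three outcomes:

- a converted attribute,
- N/A, which lets an earlier-registered callback try,
- abort, which stops the search and reports failure.

convertTypeAttribute walks the callbacks in reverse registration order, so the most recently added callback is consulted first. The sketch below models that dispatch without MLIR. It assumes attributes can be represented as strings and drops the Type argument. ConversionResult and AttrConverter are hypothetical names.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of the three-way callback outcome.
class ConversionResult {
public:
  static ConversionResult result(std::string attr) {
    return {Kind::Result, std::move(attr)};
  }
  static ConversionResult na() { return {Kind::NA, {}}; }
  static ConversionResult abort() { return {Kind::Abort, {}}; }
  bool hasResult() const { return kind == Kind::Result; }
  bool isAbort() const { return kind == Kind::Abort; }
  const std::string &getResult() const {
    assert(hasResult() && "Cannot get result from N/A or abort");
    return value;
  }

private:
  enum class Kind { Result, NA, Abort };
  ConversionResult(Kind k, std::string v) : kind(k), value(std::move(v)) {}
  Kind kind;
  std::string value;
};

using Callback = std::function<ConversionResult(const std::string &)>;

class AttrConverter {
public:
  void addConversion(Callback fn) { callbacks.push_back(std::move(fn)); }

  // Most recently added callback first; stop at the first result or abort.
  std::optional<std::string> convert(const std::string &attr) const {
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
      ConversionResult res = (*it)(attr);
      if (res.hasResult())
        return res.getResult();
      if (res.isAbort())
        return std::nullopt;
    }
    return std::nullopt; // every callback answered N/A
  }

private:
  std::vector<Callback> callbacks;
};
```

Because later registrations win, a pass can override a default mapping simply by adding its own callback after the defaults.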
diff  --git a/mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir b/mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir
new file mode 100644
index 0000000000000..3c1924e179645
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -split-input-file -convert-gpu-to-rocdl | FileCheck %s --check-prefixes=CHECK,ROCDL
+// RUN: mlir-opt %s -split-input-file -convert-gpu-to-nvvm | FileCheck %s --check-prefixes=CHECK,NVVM
+
+gpu.module @kernel {
+  gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
+    gpu.return
+  }
+}
+
+// CHECK-LABEL:  llvm.func @private
+//      CHECK:  llvm.store
+// ROCDL-SAME:   : !llvm.ptr<f32, 5>
+//  NVVM-SAME:   : !llvm.ptr<f32>
+
+
+// -----
+
+gpu.module @kernel {
+  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
+    gpu.return
+  }
+}
+
+// CHECK-LABEL:  llvm.func @workgroup
+//       CHECK:  llvm.store
+//  CHECK-SAME:   : !llvm.ptr<f32, 3>
+
+// -----
+
+gpu.module @kernel {
+  gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>) -> f32 {
+    %c0 = arith.constant 0 : index
+    %inner = memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>
+    %value = memref.load %inner[%c0] : memref<4xf32, #gpu.address_space<global>>
+    gpu.return %value : f32
+  }
+}
+
+// CHECK-LABEL:  llvm.func @nested_memref
+//       CHECK:  llvm.load
+//  CHECK-SAME:   : !llvm.ptr<{{.*}}, 1>
+//       CHECK: [[value:%.+]] = llvm.load
+//  CHECK-SAME:   : !llvm.ptr<f32, 1>
+//       CHECK: llvm.return [[value]]

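The new lower-memory-space-attrs.mlir test fixes the numeric address spaces that each GPU lowering produces for #gpu.address_space values:

- global becomes 1 for both targets,
- workgroup becomes 3 for both targets,
- private becomes 5 for ROCDL and the default 0 for NVVM.

The helper below is hypothetical. It only restates those CHECK lines and makes no claim about how either pass structures its mapping.

```cpp
#include <cassert>

enum class GpuAddressSpace { Global, Workgroup, Private };
enum class Target { ROCDL, NVVM };

// Numeric address spaces implied by the CHECK lines of the test above.
unsigned toLLVMAddressSpace(Target target, GpuAddressSpace space) {
  switch (space) {
  case GpuAddressSpace::Global:
    return 1;
  case GpuAddressSpace::Workgroup:
    return 3;
  case GpuAddressSpace::Private:
    return target == Target::ROCDL ? 5 : 0; // NVVM: plain !llvm.ptr<f32>
  }
  assert(false && "unknown address space");
  return 0;
}
```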
diff  --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
new file mode 100644
index 0000000000000..786e6f2562199
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm -split-input-file 2>&1 | FileCheck %s
+// Since the error is at an unknown location, we use FileCheck instead of
+// -verify-diagnostics here.
+
+// CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions.
+// CHECK-LABEL: @bad_address_space
+func.func @bad_address_space(%a: memref<2xindex, "foo">) {
+    %c0 = arith.constant 0 : index
+    // CHECK: memref.store
+    memref.store %c0, %a[%c0] : memref<2xindex, "foo">
+    return
+}
+
+// -----

diff  --git a/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir b/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir
deleted file mode 100644
index 9b4f1dee597b0..0000000000000
--- a/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir
+++ /dev/null
@@ -1,55 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes | FileCheck %s
-// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes="private=0 global=0" | FileCheck %s --check-prefix=CUDA
-
-gpu.module @kernel {
-  gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
-    %c0 = arith.constant 0 : index
-    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
-    gpu.return
-  }
-}
-
-//      CHECK:  gpu.func @private
-// CHECK-SAME:    private(%{{.+}}: memref<4xf32, 5>)
-//      CHECK:  memref.store
-// CHECK-SAME:   : memref<4xf32, 5>
-
-//      CUDA:  gpu.func @private
-// CUDA-SAME:    private(%{{.+}}: memref<4xf32>)
-//      CUDA:  memref.store
-// CUDA-SAME:   : memref<4xf32>
-
-// -----
-
-gpu.module @kernel {
-  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
-    %c0 = arith.constant 0 : index
-    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
-    gpu.return
-  }
-}
-
-//      CHECK:  gpu.func @workgroup
-// CHECK-SAME:    workgroup(%{{.+}}: memref<4xf32, 3>)
-//      CHECK:  memref.store
-// CHECK-SAME:   : memref<4xf32, 3>
-
-// -----
-
-gpu.module @kernel {
-  gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>) {
-    %c0 = arith.constant 0 : index
-    memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>
-    gpu.return
-  }
-}
-
-//      CHECK:  gpu.func @nested_memref
-// CHECK-SAME:    (%{{.+}}: memref<4xmemref<4xf32, 1>, 1>)
-//      CHECK:  memref.load
-// CHECK-SAME:   : memref<4xmemref<4xf32, 1>, 1>
-
-//      CUDA:  gpu.func @nested_memref
-// CUDA-SAME:    (%{{.+}}: memref<4xmemref<4xf32>>)
-//      CUDA:  memref.load
-// CUDA-SAME:   : memref<4xmemref<4xf32>>

