[Mlir-commits] [mlir] 3c9aa10 - Foo

Nicolas Vasilache llvmlistbot at llvm.org
Fri Aug 4 04:25:18 PDT 2023


Author: Nicolas Vasilache
Date: 2023-08-04T11:06:17Z
New Revision: 3c9aa10c57cf0833ff108ecf9ffbb512bd96cc89

URL: https://github.com/llvm/llvm-project/commit/3c9aa10c57cf0833ff108ecf9ffbb512bd96cc89
DIFF: https://github.com/llvm/llvm-project/commit/3c9aa10c57cf0833ff108ecf9ffbb512bd96cc89.diff

LOG: Foo

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
    mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
    mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
index bc45fcd84e0982..b6e10b29fbb20c 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
@@ -14,6 +14,9 @@
 #ifndef MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H
 #define MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/IR/DataLayout.h"
+
+#include <functional>
 
 namespace mlir {
@@ -66,6 +67,11 @@ class LowerToLLVMOptions {
   /// Get the index bitwidth.
   unsigned getIndexBitwidth() const { return indexBitwidth; }
 
+  /// Hook to customize the index type used for a given MemRefType (e.g. in
+  /// its descriptor and in index arithmetic). When unset, the default index
+  /// type is used. Stored by value so the options own the callable.
+  std::function<Type(MemRefType)> memrefIndexTypeConverter = nullptr;
+
 private:
   unsigned indexBitwidth;
 };

diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 0aee13818df4d5..52bba8b0e97ef5 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -50,6 +50,11 @@ class ConvertToLLVMPattern : public ConversionPattern {
   /// defined by the used type converter.
   Type getIndexType() const;
 
+  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
+  /// defined by the used type converter and matching the index type needed for
+  /// MemRefType `t`.
+  Type getIndexTypeMatchingMemRef(MemRefType t) const;
+
   /// Gets the MLIR type wrapping the LLVM integer type whose bit width
   /// corresponds to that of a LLVM pointer type.
   Type getIntPtrType(unsigned addressSpace = 0) const;

diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 79a68e875f045e..970fde18ad7bad 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -132,14 +132,20 @@ class LLVMTypeConverter : public TypeConverter {
   /// integer type with the size configured for this type converter.
   Type getIndexType();
 
-  /// Returns true if using opaque pointers was enabled in the lowering options.
+  /// Gets the LLVM representation of the index type that matches the MemRefType
+  /// `t`. The returned type is an integer type with the size configured for
+  /// this type converter.
+  Type getIndexTypeMatchingMemRef(MemRefType t);
+
+  /// Returns true if using opaque pointers was enabled in the lowering
+  /// options.
   bool useOpaquePointers() const { return getOptions().useOpaquePointers; }
 
   /// Creates an LLVM pointer type with the given element type and address
   /// space.
-  /// This function is meant to be used in code supporting both typed and opaque
-  /// pointers, as it will create an opaque pointer with the given address space
-  /// if opaque pointers are enabled in the lowering options.
+  /// This function is meant to be used in code supporting both typed and
+  /// opaque pointers, as it will create an opaque pointer with the given
+  /// address space if opaque pointers are enabled in the lowering options.
   LLVM::LLVMPointerType getPointerType(Type elementType,
                                        unsigned addressSpace = 0);
 
@@ -170,13 +176,13 @@ class LLVMTypeConverter : public TypeConverter {
 
 private:
   /// Convert a function type.  The arguments and results are converted one by
-  /// one.  Additionally, if the function returns more than one value, pack the
-  /// results into an LLVM IR structure type so that the converted function type
-  /// returns at most one result.
+  /// one.  Additionally, if the function returns more than one value, pack
+  /// the results into an LLVM IR structure type so that the converted
+  /// function type returns at most one result.
   Type convertFunctionType(FunctionType type);
 
-  /// Convert the index type.  Uses llvmModule data layout to create an integer
-  /// of the pointer bitwidth.
+  /// Convert the index type.  Uses llvmModule data layout to create an
+  /// integer of the pointer bitwidth.
   Type convertIndexType(IndexType type);
 
   /// Convert an integer type `i*` to `!llvm<"i*">`.
@@ -184,12 +190,13 @@ class LLVMTypeConverter : public TypeConverter {
 
   /// Convert a floating point type: `f16` to `f16`, `f32` to
   /// `f32` and `f64` to `f64`.  `bf16` is not supported
-  /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how
-  /// all LLVM backends that support them currently represent them.
+  /// by LLVM. 8-bit float types are converted to 8-bit integers as this is
+  /// how all LLVM backends that support them currently represent them.
   Type convertFloatType(FloatType type);
 
-  /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
-  /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
+  /// Convert complex number type: `complex<f16>` to
+  /// `!llvm<"{ half, half }">`, `complex<f32>` to `!llvm<"{ float, float }">`,
+  /// and `complex<f64>` to
   /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
   Type convertComplexType(ComplexType type);
 
@@ -197,10 +204,10 @@ class LLVMTypeConverter : public TypeConverter {
   Type convertMemRefType(MemRefType type);
 
   /// Convert a memref type into a list of LLVM IR types that will form the
-  /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
-  /// arrays in the descriptors are unpacked to individual index-typed elements,
-  /// else they are are kept as rank-sized arrays of index type. In particular,
-  /// the list will contain:
+  /// memref descriptor. If `unpackAggregates` is true the `sizes` and
+  /// `strides` arrays in the descriptors are unpacked to individual
+  /// index-typed elements, else they are kept as rank-sized arrays of
+  /// index type. In particular, the list will contain:
   /// - two pointers to the memref element type, followed by
   /// - an index-typed offset, followed by
   /// - (if unpackAggregates = true)
@@ -220,9 +227,9 @@ class LLVMTypeConverter : public TypeConverter {
   SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
                                                  bool unpackAggregates);
 
-  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
-  /// that will form the unranked memref descriptor. In particular, this list
-  /// contains:
+  /// Convert an unranked memref type into a list of non-aggregate LLVM IR
+  /// types that will form the unranked memref descriptor. In particular, this
+  /// list contains:
   /// - an integer rank, followed by
   /// - a pointer to the memref descriptor struct.
   /// For example, memref<*xf32> is converted to the following list:

diff  --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
index 495c4d63986f80..b35b77ccee6371 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
@@ -43,7 +43,8 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
     MemRefType memRefType = op.getType();
     Value alignment;
     if (auto alignmentAttr = op.getAlignment()) {
-      Type indexType = getIndexType();
+      Type indexType =
+          ConvertToLLVMPattern::getIndexTypeMatchingMemRef(memRefType);
       alignment =
           createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 842e63aff7bd9f..f779c0cb4592f4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,14 @@ def GPU_Dialect : Dialect {
     /// Returns the numeric value used to identify the private memory address
     /// space.
     static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+
+    /// Return true if the given MemRefType has an address space that is a
+    /// gpu::AddressSpaceAttr attribute with value `workgroup`.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute is a gpu::AddressSpaceAttr attribute
+    /// with value `workgroup`.
+    static bool isSharedMemoryAddressSpace(Attribute type);
   }];
 
   let dependentDialects = ["arith::ArithDialect"];

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 9993c093badc16..0a197b9ab01004 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -67,7 +67,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
 protected:
   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                        MemRefType type, MemRefDescriptor desc) const {
-    Type indexType = ConvertToLLVMPattern::getIndexType();
+    Type indexType = ConvertToLLVMPattern::getIndexTypeMatchingMemRef(type);
     return type.hasStaticShape()
                ? ConvertToLLVMPattern::createIndexAttrConstant(
                      rewriter, loc, indexType, type.getNumElements())
@@ -654,10 +654,18 @@ class ConvertSDDMMOpToGpuRuntimeCallPattern
 
 } // namespace
 
+/// Index type used for memrefs by this pass: i32 for memrefs in the GPU
+/// workgroup (shared) address space, i64 for everything else.
+static IntegerType getIndexTypeForMemRef(MemRefType t) {
+  unsigned numBits = gpu::GPUDialect::hasSharedMemoryAddressSpace(t) ? 32 : 64;
+  return IntegerType::get(t.getContext(), numBits);
+}
+
 void GpuToLLVMConversionPass::runOnOperation() {
   LowerToLLVMOptions options(&getContext());
   options.useOpaquePointers = useOpaquePointers;
   options.useBarePtrCallConv = hostBarePtrCallConv;
+  options.memrefIndexTypeConverter = getIndexTypeForMemRef;
 
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 1699172eb9dab3..72c6c93734fd10 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -19,6 +19,24 @@ using namespace mlir;
 // ConvertToLLVMPattern
 //===----------------------------------------------------------------------===//
 
+/// Converts the integer-typed value `src` to the integer type
+/// `desiredIndexType`: returns `src` as-is if the types already match,
+/// sign-extends if `src` is narrower, and truncates otherwise. Sign extension
+/// assumes index values are signed; `index`-typed values are not accepted.
+static Value convertToDesiredIndexType(OpBuilder &b, Location loc, Value src,
+                                       Type desiredIndexType) {
+  Type srcType = src.getType();
+  assert(isa<IntegerType>(srcType) && "expected integer-typed source value");
+  assert(isa<IntegerType>(desiredIndexType) &&
+         "expected integer desired index type");
+  if (srcType == desiredIndexType)
+    return src;
+  if (srcType.getIntOrFloatBitWidth() <
+      desiredIndexType.getIntOrFloatBitWidth())
+    return b.create<LLVM::SExtOp>(loc, desiredIndexType, src);
+  return b.create<LLVM::TruncOp>(loc, desiredIndexType, src);
+}
+
 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
@@ -38,6 +52,13 @@ Type ConvertToLLVMPattern::getIndexType() const {
   return getTypeConverter()->getIndexType();
 }
 
+/// Delegates to the LLVMTypeConverter, which applies the optional per-memref
+/// index type hook from its lowering options.
+Type ConvertToLLVMPattern::getIndexTypeMatchingMemRef(MemRefType t) const {
+  auto *converter = getTypeConverter();
+  return converter->getIndexTypeMatchingMemRef(t);
+}
+
 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
   return IntegerType::get(&getTypeConverter()->getContext(),
                           getTypeConverter()->getPointerBitwidth(addressSpace));
@@ -74,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   Value base =
       memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
 
-  Type indexType = getIndexType();
+  Type indexType = getIndexTypeMatchingMemRef(type);
   Value index;
   for (int i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
@@ -83,8 +101,11 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
           ShapedType::isDynamic(strides[i])
               ? memRefDescriptor.stride(rewriter, loc, i)
               : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
+      increment =
+          convertToDesiredIndexType(rewriter, loc, increment, indexType);
       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
     }
+    increment = convertToDesiredIndexType(rewriter, loc, increment, indexType);
     index =
         index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
   }
@@ -127,7 +148,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
 
   sizes.reserve(memRefType.getRank());
   unsigned dynamicIndex = 0;
-  Type indexType = getIndexType();
+  Type indexType = getIndexTypeMatchingMemRef(memRefType);
   for (int64_t size : memRefType.getShape()) {
     sizes.push_back(
         size == ShapedType::kDynamic
@@ -194,7 +215,7 @@ Value ConvertToLLVMPattern::getNumElements(
              static_cast<ssize_t>(dynamicSizes.size()) &&
          "dynamicSizes size doesn't match dynamic sizes count in memref shape");
 
-  Type indexType = getIndexType();
+  Type indexType = getIndexTypeMatchingMemRef(memRefType);
   Value numElements = memRefType.getRank() == 0
                           ? createIndexAttrConstant(rewriter, loc, indexType, 1)
                           : nullptr;
@@ -233,7 +254,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
   memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
 
   // Field 3: Offset in aligned pointer.
-  Type indexType = getIndexType();
+  Type indexType = getIndexTypeMatchingMemRef(memRefType);
   memRefDescriptor.setOffset(
       rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
 

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 9e03e2ffbacf83..c63be57dd850ff 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -174,6 +174,14 @@ Type LLVMTypeConverter::getIndexType() {
   return IntegerType::get(&getContext(), getIndexTypeBitwidth());
 }
 
+/// Returns the index type to use for `t`: the result of the
+/// `memrefIndexTypeConverter` lowering-option hook when one is set, otherwise
+/// the converter-wide index type from getIndexType().
+Type LLVMTypeConverter::getIndexTypeMatchingMemRef(MemRefType t) {
+  const auto &customConverter = options.memrefIndexTypeConverter;
+  return customConverter ? customConverter(t) : getIndexType();
+}
+
 LLVM::LLVMPointerType
 LLVMTypeConverter::getPointerType(Type elementType, unsigned int addressSpace) {
   if (useOpaquePointers())
@@ -339,7 +344,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
   }
   auto ptrTy = getPointerType(elementType, *addressSpace);
 
-  auto indexTy = getIndexType();
+  Type indexTy = getIndexTypeMatchingMemRef(type);
 
   SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
   auto rank = type.getRank();
@@ -358,7 +363,8 @@ unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
   // Compute the descriptor size given that of its components indicated above.
   unsigned space = *getMemRefAddressSpace(type);
   return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
-         (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
+         (1 + 2 * type.getRank()) *
+             layout.getTypeSize(getIndexTypeMatchingMemRef(type));
 }
 
 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 55372b9c9c1248..1c2cd86c53706f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -160,7 +160,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
     auto computeNumElements =
         [&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
       // Compute number of elements.
-      Type indexType = ConvertToLLVMPattern::getIndexType();
+      Type indexType = ConvertToLLVMPattern::getIndexTypeMatchingMemRef(type);
       Value numElements =
           type.isDynamicDim(0)
               ? getDynamicSize()
@@ -483,7 +483,8 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
     // The size value that we have to extract can be obtained using GEPop with
     // `dimOp.index() + 1` index argument.
     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
-        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
+        loc,
+        createIndexAttrConstant(rewriter, loc, adaptor.getIndex().getType(), 1),
         adaptor.getIndex());
     Value sizePtr = rewriter.create<LLVM::GEPOp>(
         loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
@@ -510,7 +511,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
 
     // Take advantage if index is constant.
     MemRefType memRefType = cast<MemRefType>(operandType);
-    Type indexType = getIndexType();
+    Type indexType = getIndexTypeMatchingMemRef(memRefType);
     if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
       int64_t i = *index;
       if (i >= 0 && i < memRefType.getRank()) {
@@ -1360,7 +1361,7 @@ struct MemRefReshapeOpLowering
       assert(targetMemRefType.getLayout().isIdentity() &&
              "Identity layout map is a precondition of a valid reshape op");
 
-      Type indexType = getIndexType();
+      Type indexType = getIndexTypeMatchingMemRef(targetMemRefType);
       Value stride = nullptr;
       int64_t targetRank = targetMemRefType.getRank();
       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
@@ -1455,7 +1456,8 @@ struct MemRefReshapeOpLowering
     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
-    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
+    Value oneIndex =
+        createIndexAttrConstant(rewriter, loc, resultRank.getType(), 1);
     Value resultRankMinusOne =
         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
 
@@ -1708,7 +1710,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
 
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
-    Type indexType = getIndexType();
+    auto indexType = targetMemRef.getIndexType();
     // Field 3: The offset in the resulting type must be 0. This is
     // because of the type change: an offset on srcType* may not be
     // expressible as an offset on dstType*.
@@ -1865,7 +1867,7 @@ class ExtractStridedMetadataOpLowering
 
 } // namespace
 
-static void populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
+void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   // clang-format off
   patterns.add<
@@ -1881,6 +1883,7 @@ static void populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
       GetGlobalMemrefOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
+      MemRefCopyOpLowering,
       MemorySpaceCastOpLowering,
       MemRefReinterpretCastOpLowering,
       MemRefReshapeOpLowering,
@@ -1893,15 +1896,6 @@ static void populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
       TransposeOpLowering,
       ViewOpLowering>(converter);
   // clang-format on
-}
-
-void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  // clang-format off
-  patterns.add<
-      MemRefCopyOpLowering>(converter);
-  // clang-format on
-
   auto allocLowering = converter.getOptions().allocLowering;
   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
     patterns.add<AlignedAllocOpLowering, AlignedReallocOpLowering,
@@ -1909,9 +1903,6 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
     patterns.add<AllocOpLowering, ReallocOpLowering, DeallocOpLowering>(
         converter);
-
-  populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(converter,
-                                                                  patterns);
 }
 
 namespace {
@@ -1940,12 +1931,6 @@ struct FinalizeMemRefToLLVMConversionPass
                                     &dataLayoutAnalysis);
     RewritePatternSet patterns(&getContext());
     populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
-    if (isa<ModuleOp>(getOperation())) {
-      populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
-    } else {
-      populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
-          typeConverter, patterns);
-    }
     LLVMConversionTarget target(getContext());
     target.addLegalOp<func::FuncOp>();
     if (failed(applyPartialConversion(op, target, std::move(patterns))))

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index c9f378c181e36d..cec2830e10b932 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -35,6 +35,19 @@ using namespace mlir::gpu;
 
 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
 
+/// Returns true if the memory space of `type` is a gpu::AddressSpaceAttr
+/// whose value is `workgroup` (i.e. GPU shared memory).
+bool gpu::GPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  return isSharedMemoryAddressSpace(type.getMemorySpace());
+}
+
+/// Returns true if `memorySpace` is a gpu::AddressSpaceAttr whose value is
+/// `workgroup`. A null (default) memory space is never shared memory.
+bool gpu::GPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  auto gpuAttr = llvm::dyn_cast_if_present<gpu::AddressSpaceAttr>(memorySpace);
+  return gpuAttr && gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+}
+
 //===----------------------------------------------------------------------===//
 // GPU Device Mapping Attributes
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list