[Mlir-commits] [mlir] 7fb9bbe - [mlir][Memref] Add memref.memory_space_cast and its lowerings

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Feb 9 13:45:03 PST 2023


Author: Krzysztof Drewniak
Date: 2023-02-09T21:44:57Z
New Revision: 7fb9bbe5f0c850ae9480e7a35b7e92e721c26039

URL: https://github.com/llvm/llvm-project/commit/7fb9bbe5f0c850ae9480e7a35b7e92e721c26039
DIFF: https://github.com/llvm/llvm-project/commit/7fb9bbe5f0c850ae9480e7a35b7e92e721c26039.diff

LOG: [mlir][Memref] Add memref.memory_space_cast and its lowerings

Address space casts are present in common MLIR targets (LLVM, SPIRV).
Some planned rewrites (such as one of the potential fixes for the fact
that the AMDGPU backend requires alloca() to live in address space 5,
the GPU private memory space) may need to insert such casts into MLIR
code, where those address spaces could be represented by arbitrary
memory space attributes.

Therefore, we define memref.memory_space_cast and its lowerings.
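
As a concrete illustration of the intended use (a hypothetical sketch, not
IR from this patch; the rewrite producing it does not exist yet), such a
pass could allocate in the GPU private space and cast the result back to
the default memory space expected by downstream code:

```mlir
// Hypothetical output of an AMDGPU-targeted rewrite: the alloca lands in
// address space 5 (GPU private memory) and is then cast back to the
// default memory space.
func.func @private_alloca_to_generic() -> memref<16xf32> {
  %alloca = memref.alloca() : memref<16xf32, 5>
  %cast = memref.memory_space_cast %alloca : memref<16xf32, 5> to memref<16xf32>
  return %cast : memref<16xf32>
}
```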

Depends on D141293

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D141148

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index a68e0879444db..76ed89563fd7b 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -218,6 +218,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
                             LLVM::LLVMPointerType elemPtrType,
                             Value alignedPtr);
 
+  /// Builds IR for getting the pointer to the offset's location.
+  /// Returns a pointer to a convertType(index), which points to the beginning
+  /// of a struct {index, index[rank], index[rank]}.
+  static Value offsetBasePtr(OpBuilder &builder, Location loc,
+                             LLVMTypeConverter &typeConverter,
+                             Value memRefDescPtr,
+                             LLVM::LLVMPointerType elemPtrType);
   /// Builds IR extracting the offset from the descriptor.
   static Value offset(OpBuilder &builder, Location loc,
                       LLVMTypeConverter &typeConverter, Value memRefDescPtr,

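For context (an editorial note, not part of the patch): the underlying
descriptor this helper indexes into is laid out as {allocated ptr,
aligned ptr, offset, sizes[rank], strides[rank]}, so stepping over the two
leading pointers reaches the offset field. A minimal sketch of the IR the
helper builds, assuming opaque pointers and a hypothetical %desc value:

```mlir
// %desc : !llvm.ptr points at { ptr, ptr, index, index[rank], index[rank] }.
// GEP-ing over two pointer-sized elements lands on the offset field; the
// sizes and strides arrays follow it contiguously, which is what lets the
// memory-space-cast lowering below memcpy them in one shot.
%offset_ptr = llvm.getelementptr %desc[2] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
```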
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ea4dc1ac2f218..fdc070fdd068e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1177,6 +1177,54 @@ def LoadOp : MemRef_Op<"load",
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
 }
 
+//===----------------------------------------------------------------------===//
+// MemorySpaceCastOp
+//===----------------------------------------------------------------------===//
+def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
+      DeclareOpInterfaceMethods<CastOpInterface>,
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      MemRefsNormalizable,
+      Pure,
+      SameOperandsAndResultElementType,
+      SameOperandsAndResultShape,
+      ViewLikeOpInterface
+    ]> {
+  let summary = "memref memory space cast operation";
+  let description = [{
+    This operation casts memref values between memory spaces.
+    The input and result will be memrefs of the same type and shape that alias
+    the same underlying memory, though, for some casts on some targets,
+    the underlying value of the pointer stored in the memref may be affected
+    by the cast.
+
+    The input and result must have the same shape, element type, rank, and layout.
+
+    If the source and target address spaces are the same, this operation is a no-op.
+
+    Example:
+
+    ```mlir
+    // Cast a GPU private memory attribution into a generic pointer
+    %2 = memref.memory_space_cast %1 : memref<?xf32, 5> to memref<?xf32>
+    // Cast a generic pointer to workgroup-local memory
+    %4 = memref.memory_space_cast %3 : memref<5x4xi32> to memref<5x4xi32, 3>
+    // Cast between two non-default memory spaces
+    %6 = memref.memory_space_cast %5
+      : memref<*xmemref<?xf32>, 5> to memref<*xmemref<?xf32>, 3>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
+  let results = (outs AnyRankedOrUnrankedMemRef:$dest);
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+
+  let extraClassDeclaration = [{
+    Value getViewSource() { return getSource(); }
+  }];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 17259e48e3a3a..12f63cba579f4 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -10,6 +10,7 @@
 #include "MemRefDescriptor.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Support/MathExtras.h"
 
@@ -457,10 +458,9 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
   builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
 }
 
-Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
-                                       LLVMTypeConverter &typeConverter,
-                                       Value memRefDescPtr,
-                                       LLVM::LLVMPointerType elemPtrType) {
+Value UnrankedMemRefDescriptor::offsetBasePtr(
+    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+    Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
   auto [elementPtrPtr, elemPtrPtrType] =
       castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
 
@@ -473,9 +473,16 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
         loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()),
         offsetGep);
   }
+  return offsetGep;
+}
 
-  return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(),
-                                      offsetGep);
+Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter &typeConverter,
+                                       Value memRefDescPtr,
+                                       LLVM::LLVMPointerType elemPtrType) {
+  Value offsetPtr =
+      offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
+  return builder.create<LLVM::LoadOp>(loc, offsetPtr);
 }
 
 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
@@ -483,20 +490,9 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
                                          Value memRefDescPtr,
                                          LLVM::LLVMPointerType elemPtrType,
                                          Value offset) {
-  auto [elementPtrPtr, elemPtrPtrType] =
-      castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
-
-  Value offsetGep =
-      builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
-                                  elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
-
-  if (!elemPtrType.isOpaque()) {
-    offsetGep = builder.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()),
-        offsetGep);
-  }
-
-  builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
+  Value offsetPtr =
+      offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
+  builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
 }
 
 Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 5dccb9b5f9ea9..700304d56df86 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -17,10 +17,12 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include <optional>
 
@@ -1096,6 +1098,118 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
   }
 };
 
+struct MemorySpaceCastOpLowering
+    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
+  using ConvertOpToLLVMPattern<
+      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Type resultType = op.getDest().getType();
+    if (auto resultTypeR = resultType.dyn_cast<MemRefType>()) {
+      auto resultDescType =
+          typeConverter->convertType(resultTypeR).cast<LLVM::LLVMStructType>();
+      Type newPtrType = resultDescType.getBody()[0];
+
+      SmallVector<Value> descVals;
+      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
+                               descVals);
+      descVals[0] =
+          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
+      descVals[1] =
+          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
+      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
+                                            resultTypeR, descVals);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+    if (auto resultTypeU = resultType.dyn_cast<UnrankedMemRefType>()) {
+      // Since the type converter won't be doing this for us, get the address
+      // space.
+      auto sourceType = op.getSource().getType().cast<UnrankedMemRefType>();
+      FailureOr<unsigned> maybeSourceAddrSpace =
+          getTypeConverter()->getMemRefAddressSpace(sourceType);
+      if (failed(maybeSourceAddrSpace))
+        return rewriter.notifyMatchFailure(loc,
+                                           "non-integer source address space");
+      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
+      FailureOr<unsigned> maybeResultAddrSpace =
+          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
+      if (failed(maybeResultAddrSpace))
+        return rewriter.notifyMatchFailure(loc,
+                                           "non-integer result address space");
+      unsigned resultAddrSpace = *maybeResultAddrSpace;
+
+      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
+      Value rank = sourceDesc.rank(rewriter, loc);
+      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
+
+      // Create and allocate storage for new memref descriptor.
+      auto result = UnrankedMemRefDescriptor::undef(
+          rewriter, loc, typeConverter->convertType(resultTypeU));
+      result.setRank(rewriter, loc, rank);
+      SmallVector<Value, 1> sizes;
+      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
+                                             result, resultAddrSpace, sizes);
+      Value resultUnderlyingSize = sizes.front();
+      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
+          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
+      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
+
+      // Copy pointers, performing address space casts.
+      Type llvmElementType =
+          typeConverter->convertType(sourceType.getElementType());
+      LLVM::LLVMPointerType sourceElemPtrType =
+          getTypeConverter()->getPointerType(llvmElementType, sourceAddrSpace);
+      auto resultElemPtrType =
+          getTypeConverter()->getPointerType(llvmElementType, resultAddrSpace);
+
+      Value allocatedPtr = sourceDesc.allocatedPtr(
+          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
+      Value alignedPtr =
+          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
+                                sourceUnderlyingDesc, sourceElemPtrType);
+      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+          loc, resultElemPtrType, allocatedPtr);
+      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+          loc, resultElemPtrType, alignedPtr);
+
+      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
+                             resultElemPtrType, allocatedPtr);
+      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
+                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);
+
+      // Copy all the index-valued operands.
+      Value sourceIndexVals =
+          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
+                                   sourceUnderlyingDesc, sourceElemPtrType);
+      Value resultIndexVals =
+          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
+                               resultUnderlyingDesc, resultElemPtrType);
+
+      int64_t bytesToSkip =
+          2 *
+          ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
+      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
+          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
+      Value copySize = rewriter.create<LLVM::SubOp>(
+          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
+      Type llvmBool = typeConverter->convertType(rewriter.getI1Type());
+      Value nonVolatile = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmBool, rewriter.getBoolAttr(false));
+      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
+                                      copySize, nonVolatile);
+
+      rewriter.replaceOp(op, ValueRange{result});
+      return success();
+    }
+    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
+  }
+};
+
 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
 /// memref type. In unranked case, the fields are extracted from the underlying
 /// ranked descriptor.
@@ -1785,6 +1899,7 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
       LoadOpLowering,
       MemRefCastOpLowering,
       MemRefCopyOpLowering,
+      MemorySpaceCastOpLowering,
       MemRefReinterpretCastOpLowering,
       MemRefReshapeOpLowering,
       PrefetchOpLowering,

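Summarizing the two paths above (an editorial sketch, not text from the
patch): the ranked case is purely value-level, while the unranked case must
allocate new storage for the result descriptor and memcpy the index fields.
In the ranked case, the only operations that touch the data are the two
pointer casts, sketched here for a default-space source and a space-1 result:

```mlir
// Assumed names: %alloc and %align were just extracted from the source
// descriptor. Offset, sizes, and strides are reinserted into the result
// descriptor unchanged.
%alloc_cast = llvm.addrspacecast %alloc : !llvm.ptr to !llvm.ptr<1>
%align_cast = llvm.addrspacecast %align : !llvm.ptr to !llvm.ptr<1>
```

The full descriptor repacking is checked in the convert-static-memref-ops.mlir
test updated below.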
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index c6ce3f11a6737..5a37806034018 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -223,6 +223,17 @@ class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
+class MemorySpaceCastOpPattern final
+    : public OpConversionPattern<memref::MemorySpaceCastOp> {
+public:
+  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts memref.store to spirv.Store.
 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
 public:
@@ -552,6 +563,74 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MemorySpaceCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
+    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = addrCastOp.getLoc();
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  if (!typeConverter.allows(spirv::Capability::Kernel))
+    return rewriter.notifyMatchFailure(
+        loc, "address space casts require kernel capability");
+
+  auto sourceType = addrCastOp.getSource().getType().dyn_cast<MemRefType>();
+  if (!sourceType)
+    return rewriter.notifyMatchFailure(
+        loc, "SPIR-V lowering requires ranked memref types");
+  auto resultType = addrCastOp.getResult().getType().cast<MemRefType>();
+
+  auto sourceStorageClassAttr =
+      sourceType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!sourceStorageClassAttr)
+    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
+      diag << "source address space " << sourceType.getMemorySpace()
+           << " must be a SPIR-V storage class";
+    });
+  auto resultStorageClassAttr =
+      resultType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!resultStorageClassAttr)
+    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
+      diag << "result address space " << resultType.getMemorySpace()
+           << " must be a SPIR-V storage class";
+    });
+
+  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
+  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
+
+  Value result = adaptor.getSource();
+  Type resultPtrType = typeConverter.convertType(resultType);
+  Type genericPtrType = resultPtrType;
+  // SPIR-V doesn't have a general address space cast operation. Instead, it has
+  // conversions to and from generic pointers. To implement the general case,
+  // we use a specific-to-generic conversion when the source storage class is
+  // not generic. Then, when the result storage class is not generic, we
+  // convert the generic pointer (either the input or an intermediate result)
+  // to that class. This also means that we'll need the intermediate generic
+  // pointer type if neither the source nor the destination has it.
+  if (sourceSc != spirv::StorageClass::Generic &&
+      resultSc != spirv::StorageClass::Generic) {
+    Type intermediateType =
+        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
+                        sourceType.getLayout(),
+                        rewriter.getAttr<spirv::StorageClassAttr>(
+                            spirv::StorageClass::Generic));
+    genericPtrType = typeConverter.convertType(intermediateType);
+  }
+  if (sourceSc != spirv::StorageClass::Generic) {
+    result =
+        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
+  }
+  if (resultSc != spirv::StorageClass::Generic) {
+    result =
+        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
+  }
+  rewriter.replaceOp(addrCastOp, result);
+  return success();
+}
+
 LogicalResult
 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
@@ -577,9 +656,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns
-      .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
-           IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
-          typeConverter, patterns.getContext());
+  patterns.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern,
+               IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+               MemorySpaceCastOpPattern, StoreOpPattern>(typeConverter,
+                                                         patterns.getContext());
 }
 } // namespace mlir

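Since SPIR-V offers no direct cast between two non-generic storage classes,
the pattern above round-trips through Generic. A sketch of the emitted pair
for a CrossWorkgroup-to-Function cast (matching the new memref-to-spirv.mlir
test later in this patch):

```mlir
// Two-step cast through the Generic storage class.
%generic = spirv.PtrCastToGeneric %src
    : !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup> to !spirv.ptr<!spirv.array<4 x f32>, Generic>
%dst = spirv.GenericCastToPtr %generic
    : !spirv.ptr<!spirv.array<4 x f32>, Generic> to !spirv.ptr<!spirv.array<4 x f32>, Function>
```

When either side is already Generic, only the corresponding single cast is
emitted, and a Generic-to-Generic cast emits no code at all.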
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d3a1ae1663c01..02f8019996cdd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1684,6 +1684,50 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+//===----------------------------------------------------------------------===//
+// MemorySpaceCastOp
+//===----------------------------------------------------------------------===//
+
+void MemorySpaceCastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "memspacecast");
+}
+
+bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
+  auto aT = a.dyn_cast<MemRefType>();
+  auto bT = b.dyn_cast<MemRefType>();
+
+  auto uaT = a.dyn_cast<UnrankedMemRefType>();
+  auto ubT = b.dyn_cast<UnrankedMemRefType>();
+
+  if (aT && bT) {
+    if (aT.getElementType() != bT.getElementType())
+      return false;
+    if (aT.getLayout() != bT.getLayout())
+      return false;
+    if (aT.getShape() != bT.getShape())
+      return false;
+    return true;
+  }
+  if (uaT && ubT) {
+    return uaT.getElementType() == ubT.getElementType();
+  }
+  return false;
+}
+
+OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
+  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
+  // t2)
+  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
+    getSourceMutable().assign(parentCast.getSource());
+    return getResult();
+  }
+  return Value{};
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//

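The folder's effect, sketched (this mirrors the new canonicalize.mlir test
below): a chain of casts collapses into a single cast from the original
source.

```mlir
// Before folding:
%0 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32, 1>
%1 = memref.memory_space_cast %0 : memref<?xf32, 1> to memref<?xf32, 2>
// After folding, the second cast reads directly from %arg:
%1 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32, 2>
```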
diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 9624b18f30dde..7b9c00cd6ca9b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -258,6 +258,121 @@ func.func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val :
 
 // -----
 
+// FIXME: the *ToLLVM passes don't use information from data layouts
+// to set address spaces, so the constants below don't reflect the layout.
+// Update this test once the data layout attribute works as expected.
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!llvm.ptr, dense<[64, 64, 64]> : vector<3xi32>>,
+  #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 32, 32]> : vector<3xi32>>> }  {
+  // CHECK-LABEL: @memref_memory_space_cast
+  func.func @memref_memory_space_cast(%input : memref<*xf32>) -> memref<*xf32, 1> {
+    %cast = memref.memory_space_cast %input : memref<*xf32> to memref<*xf32, 1>
+    return %cast : memref<*xf32, 1>
+  }
+}
+// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %{{.*}} to !llvm.struct<(i64, ptr)>
+// CHECK: [[RANK:%.*]] = llvm.extractvalue [[INPUT]][0] : !llvm.struct<(i64, ptr)>
+// CHECK: [[SOURCE_DESC:%.*]] = llvm.extractvalue [[INPUT]][1]
+// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[RANK]], [[RESULT_0]][0] : !llvm.struct<(i64, ptr)>
+
+// Compute size in bytes to allocate result ranked descriptor
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]]
+// CHECK: [[DOUBLE_RANK:%.*]] = llvm.mul [[C2]], %{{.*}}
+// CHECK: [[NUM_INDEX_VALS:%.*]] = llvm.add [[DOUBLE_RANK]], [[C1]]
+// CHECK: [[INDEX_VALS_SIZE:%.*]] = llvm.mul [[NUM_INDEX_VALS]], [[INDEX_SIZE]]
+// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], [[INDEX_VALS_SIZE]]
+// CHECK: [[RESULT_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
+// CHECK: llvm.insertvalue [[RESULT_DESC]], [[RESULT_1]][1]
+
+// Cast pointers
+// CHECK: [[SOURCE_ALLOC:%.*]] = llvm.load [[SOURCE_DESC]]
+// CHECK: [[SOURCE_ALIGN_GEP:%.*]] = llvm.getelementptr [[SOURCE_DESC]][1]
+// CHECK: [[SOURCE_ALIGN:%.*]] = llvm.load [[SOURCE_ALIGN_GEP]] : !llvm.ptr
+// CHECK: [[RESULT_ALLOC:%.*]] = llvm.addrspacecast [[SOURCE_ALLOC]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK: [[RESULT_ALIGN:%.*]] = llvm.addrspacecast [[SOURCE_ALIGN]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK: llvm.store [[RESULT_ALLOC]], [[RESULT_DESC]] : !llvm.ptr
+// CHECK: [[RESULT_ALIGN_GEP:%.*]] = llvm.getelementptr [[RESULT_DESC]][1]
+// CHECK: llvm.store [[RESULT_ALIGN]], [[RESULT_ALIGN_GEP]] : !llvm.ptr
+
+// Memcpy the remaining values (offset, sizes, and strides)
+
+// CHECK: [[SOURCE_OFFSET_GEP:%.*]] = llvm.getelementptr [[SOURCE_DESC]][2]
+// CHECK: [[RESULT_OFFSET_GEP:%.*]] = llvm.getelementptr [[RESULT_DESC]][2]
+// CHECK: [[SIZEOF_TWO_RESULT_PTRS:%.*]] = llvm.mlir.constant(16 : index) : i64
+// CHECK: [[COPY_SIZE:%.*]] = llvm.sub [[DESC_ALLOC_SIZE]], [[SIZEOF_TWO_RESULT_PTRS]]
+// CHECK: [[FALSE:%.*]] = llvm.mlir.constant(false) : i1
+// CHECK: "llvm.intr.memcpy"([[RESULT_OFFSET_GEP]], [[SOURCE_OFFSET_GEP]], [[COPY_SIZE]], [[FALSE]])
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_static_to_dynamic
+func.func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %static : memref<10x42xf32> to memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_static_to_mixed
+func.func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %static : memref<10x42xf32> to memref<?x42xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_dynamic_to_static
+func.func @memref_cast_dynamic_to_static(%dynamic : memref<?x?xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_dynamic_to_mixed
+func.func @memref_cast_dynamic_to_mixed(%dynamic : memref<?x?xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %dynamic : memref<?x?xf32> to memref<?x12xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_mixed_to_dynamic
+func.func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %mixed : memref<42x?xf32> to memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_mixed_to_static
+func.func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %mixed : memref<42x?xf32> to memref<42x1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_mixed_to_mixed
+func.func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %mixed : memref<42x?xf32> to memref<?x1xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_cast_ranked_to_unranked
 // CHECK32-LABEL: func @memref_cast_ranked_to_unranked
 func.func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index d9552972e2d90..24877cc299a14 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -411,3 +411,27 @@ func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{
   %out = memref.realloc %in {alignment = 8} : memref<2xf32> to memref<4xf32>
   return %out : memref<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @memref_memory_space_cast
+func.func @memref_memory_space_cast(%input : memref<?xf32>) -> memref<?xf32, 1> {
+  %cast = memref.memory_space_cast %input : memref<?xf32> to memref<?xf32, 1>
+  return %cast : memref<?xf32, 1>
+}
+// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %{{.*}}
+// CHECK: [[ALLOC:%.*]] = llvm.extractvalue [[INPUT]][0]
+// CHECK: [[ALIGN:%.*]] = llvm.extractvalue [[INPUT]][1]
+// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2]
+// CHECK: [[SIZE:%.*]] = llvm.extractvalue [[INPUT]][3, 0]
+// CHECK: [[STRIDE:%.*]] = llvm.extractvalue [[INPUT]][4, 0]
+// CHECK: [[CAST_ALLOC:%.*]] = llvm.addrspacecast [[ALLOC]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK: [[CAST_ALIGN:%.*]] = llvm.addrspacecast [[ALIGN]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef
+// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[CAST_ALLOC]], [[RESULT_0]][0]
+// CHECK: [[RESULT_2:%.*]] = llvm.insertvalue [[CAST_ALIGN]], [[RESULT_1]][1]
+// CHECK: [[RESULT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[RESULT_2]][2]
+// CHECK: [[RESULT_4:%.*]] = llvm.insertvalue [[SIZE]], [[RESULT_3]][3, 0]
+// CHECK: [[RESULT_5:%.*]] = llvm.insertvalue [[STRIDE]], [[RESULT_4]][4, 0]
+// CHECK: [[RESULT:%.*]] = builtin.unrealized_conversion_cast [[RESULT_5]] : {{.*}} to memref<?xf32, 1>
+// CHECK: return [[RESULT]]

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 67556f00c2f21..a31c24657f8bc 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -212,6 +212,32 @@ func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
 
 // -----
 
+// Check address space casts
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0,
+      [
+        Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func.func @memory_space_cast
+func.func @memory_space_cast(%arg: memref<4xf32, #spirv.storage_class<CrossWorkgroup>>)
+    -> memref<4xf32, #spirv.storage_class<Function>> {
+  // CHECK: %[[ARG_CAST:.+]] = builtin.unrealized_conversion_cast {{.*}} to !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup>
+  // CHECK: %[[TO_GENERIC:.+]] = spirv.PtrCastToGeneric %[[ARG_CAST]] : !spirv.ptr<!spirv.array<4 x f32>, CrossWorkgroup> to !spirv.ptr<!spirv.array<4 x f32>, Generic>
+  // CHECK: %[[TO_PRIVATE:.+]] = spirv.GenericCastToPtr %[[TO_GENERIC]] : !spirv.ptr<!spirv.array<4 x f32>, Generic> to !spirv.ptr<!spirv.array<4 x f32>, Function>
+  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[TO_PRIVATE]]
+  // CHECK: return %[[RET]]
+  %ret = memref.memory_space_cast %arg : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
+    to memref<4xf32, #spirv.storage_class<Function>>
+  return %ret : memref<4xf32, #spirv.storage_class<Function>>
+}
+
+} // end module
+
+// -----
+
 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
 // TODO: Test i64 types.

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 14c570759b4c1..4295947226433 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -906,3 +906,25 @@ func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0
   memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_trivial_memory_space_cast(
+//  CHECK-SAME:     %[[arg:.*]]: memref<?xf32>
+//       CHECK:   return %[[arg]]
+func.func @fold_trivial_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32> {
+  %0 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32>
+  return %0 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_multiple_memory_space_cast(
+//  CHECK-SAME:     %[[arg:.*]]: memref<?xf32>
+//       CHECK:   %[[res:.*]] = memref.memory_space_cast %[[arg]] : memref<?xf32> to memref<?xf32, 2>
+//       CHECK:   return %[[res]]
+func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32, 2> {
+  %0 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32, 1>
+  %1 = memref.memory_space_cast %0 : memref<?xf32, 1> to memref<?xf32, 2>
+  return %1 : memref<?xf32, 2>
+}

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index fbff761565fab..79de75117dea0 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -380,3 +380,9 @@ func.func @memref_extract_aligned_pointer(%src : memref<?xf32>) -> index {
   %0 = memref.extract_aligned_pointer_as_index %src : memref<?xf32> -> index
   return %0 : index
 }
+
+// CHECK-LABEL: func @memref_memory_space_cast
+func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
+  %dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
+  return %dst : memref<?xf32, 1>
+}

