[Mlir-commits] [mlir] c59ce1f - [mlir] support memref of memref in standard-to-llvm conversion

Alex Zinenko llvmlistbot at llvm.org
Tue Jun 8 02:11:43 PDT 2021


Author: Alex Zinenko
Date: 2021-06-08T11:11:31+02:00
New Revision: c59ce1f6257c88330c1f1757c36d59d34fe29248

URL: https://github.com/llvm/llvm-project/commit/c59ce1f6257c88330c1f1757c36d59d34fe29248
DIFF: https://github.com/llvm/llvm-project/commit/c59ce1f6257c88330c1f1757c36d59d34fe29248.diff

LOG: [mlir] support memref of memref in standard-to-llvm conversion

Now that memref supports arbitrary element types, add support for memrefs of
memrefs and make sure they are properly converted to the LLVM dialect. As with
other built-in types, the memref type itself does not implement the data layout
interface. This keeps the shape, and therefore the byte size, of the memref
descriptor a lowering detail that is easier to customize and evolve, rather
than enshrining it in the data layout specification of the memref type
itself.

Factor out the code previously in a testing pass to live in a dedicated data
layout analysis and use that analysis in the conversion to compute the
allocation size for memref of memref. Other conversions will be ported
separately.

Depends On D103827

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D103828

Added: 
    mlir/include/mlir/Analysis/DataLayoutAnalysis.h
    mlir/lib/Analysis/DataLayoutAnalysis.cpp

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/Analysis/CMakeLists.txt
    mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
    mlir/test/IR/parser.mlir
    mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataLayoutAnalysis.h b/mlir/include/mlir/Analysis/DataLayoutAnalysis.h
new file mode 100644
index 0000000000000..c190c8006b14d
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataLayoutAnalysis.h
@@ -0,0 +1,48 @@
+//===- DataLayoutAnalysis.h - API for Querying Nested Data Layout -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATALAYOUTANALYSIS_H
+#define MLIR_ANALYSIS_DATALAYOUTANALYSIS_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+
+#include <memory>
+
+namespace mlir {
+
+class Operation;
+class DataLayout;
+
+/// Stores data layout objects for each operation that can specify a data
+/// layout, both nested within the root operation and among its ancestors.
+class DataLayoutAnalysis {
+public:
+  /// Constructs data layouts for `root`, its descendants and its ancestors.
+  explicit DataLayoutAnalysis(Operation *root);
+
+  /// Returns the data layout active at the given operation, that is the data
+  /// layout specified by the closest proper ancestor that can specify one, or
+  /// the default layout if there is no such ancestor.
+  const DataLayout &getAbove(Operation *operation) const;
+
+  /// Returns the data layout specified by the given operation itself or, if it
+  /// specifies none, the one returned by `getAbove` for it.
+  const DataLayout &getAtOrAbove(Operation *operation) const;
+
+private:
+  /// Data layouts keyed by the operation that specifies them.
+  DenseMap<Operation *, std::unique_ptr<DataLayout>> layouts;
+
+  /// Default data layout in case no operations specify one.
+  std::unique_ptr<DataLayout> defaultLayout;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_DATALAYOUTANALYSIS_H

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 07b2df58ea1bb..58e9c73562879 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -29,6 +29,7 @@ namespace mlir {
 
 class BaseMemRefType;
 class ComplexType;
+class DataLayoutAnalysis;
 class LLVMTypeConverter;
 class UnrankedMemRefType;
 
@@ -62,10 +63,14 @@ class LLVMTypeConverter : public TypeConverter {
   using TypeConverter::convertType;
 
   /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
-  LLVMTypeConverter(MLIRContext *ctx);
+  /// Optionally takes a data layout analysis to use in conversions.
+  LLVMTypeConverter(MLIRContext *ctx,
+                    const DataLayoutAnalysis *analysis = nullptr);
 
-  /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
-  LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
+  /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally
+  /// takes a data layout analysis to use in conversions.
+  LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
+                    const DataLayoutAnalysis *analysis = nullptr);
 
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
@@ -124,6 +129,11 @@ class LLVMTypeConverter : public TypeConverter {
   /// Returns the data layout to use during and after conversion.
   const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
 
+  /// Returns the data layout analysis to query during conversion.
+  const DataLayoutAnalysis *getDataLayoutAnalysis() const {
+    return dataLayoutAnalysis;
+  }
+
   /// Gets the LLVM representation of the index type. The returned type is an
   /// integer type with the size configured for this type converter.
   Type getIndexType();
@@ -134,6 +144,13 @@ class LLVMTypeConverter : public TypeConverter {
   /// Gets the pointer bitwidth.
   unsigned getPointerBitwidth(unsigned addressSpace = 0);
 
+  /// Returns the size of the memref descriptor object in bytes.
+  unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout);
+
+  /// Returns the size of the unranked memref descriptor object in bytes.
+  unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
+                                           const DataLayout &layout);
+
 protected:
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
@@ -207,11 +224,14 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a memref type to a bare pointer to the memref element type.
   Type convertMemRefToBarePtr(BaseMemRefType type);
 
-  // Convert a 1D vector type into an LLVM vector type.
+  /// Convert a 1D vector type into an LLVM vector type.
   Type convertVectorType(VectorType type);
 
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
+
+  /// Data layout analysis mapping scopes to layouts active in them.
+  const DataLayoutAnalysis *dataLayoutAnalysis;
 };
 
 /// Helper class to produce LLVM dialect operations extracting or inserting
@@ -634,11 +654,6 @@ struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
     return op->getResult(0).getType().cast<MemRefType>();
   }
 
-  LogicalResult match(Operation *op) const override {
-    MemRefType memRefType = getMemRefResultType(op);
-    return success(isConvertibleAndHasIdentityMaps(memRefType));
-  }
-
   // An `alloc` is converted into a definition of a memref descriptor value and
   // a call to `malloc` to allocate the underlying data buffer.  The memref
   // descriptor is of the LLVM structure type where:
@@ -655,8 +670,9 @@ struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
 
   // An `alloca` is converted into a definition of a memref descriptor value and
   // an llvm.alloca to allocate the underlying data buffer.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 namespace LLVM {

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 25be9f9d7c93c..f96add7d698e1 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -32,7 +32,7 @@ static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
 class LowerToLLVMOptions {
 public:
   explicit LowerToLLVMOptions(MLIRContext *ctx);
-  explicit LowerToLLVMOptions(MLIRContext *ctx, DataLayout dl);
+  explicit LowerToLLVMOptions(MLIRContext *ctx, const DataLayout &dl);
 
   bool useBarePtrCallConv = false;
   bool emitCWrappers = false;

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index d858c3129091b..89745db764c2e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -272,7 +272,8 @@ inline bool BaseMemRefType::classof(Type type) {
 }
 
 inline bool BaseMemRefType::isValidElementType(Type type) {
-  return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>() ||
+  return type.isIntOrIndexOrFloat() ||
+         type.isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>() ||
          type.isa<MemRefElementTypeInterface>();
 }
 

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 85787afc49547..b142f6e4865d1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -313,6 +313,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
     - built-in index type;
     - built-in floating point types;
     - built-in vector types with elements of the above types;
+    - another memref type;
     - any other type implementing `MemRefElementTypeInterface`.
 
     ##### Codegen of Unranked Memref

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index c996794b5c681..0a2ca03877610 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
   BufferViewFlowAnalysis.cpp
   CallGraph.cpp
   DataFlowAnalysis.cpp
+  DataLayoutAnalysis.cpp
   LinearTransform.cpp
   Liveness.cpp
   LoopAnalysis.cpp
@@ -22,6 +23,7 @@ add_mlir_library(MLIRAnalysis
   BufferViewFlowAnalysis.cpp
   CallGraph.cpp
   DataFlowAnalysis.cpp
+  DataLayoutAnalysis.cpp
   Liveness.cpp
   NumberOfExecutions.cpp
   SliceAnalysis.cpp

diff  --git a/mlir/lib/Analysis/DataLayoutAnalysis.cpp b/mlir/lib/Analysis/DataLayoutAnalysis.cpp
new file mode 100644
index 0000000000000..71b6461857406
--- /dev/null
+++ b/mlir/lib/Analysis/DataLayoutAnalysis.cpp
@@ -0,0 +1,51 @@
+//===- DataLayoutAnalysis.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+using namespace mlir;
+
+DataLayoutAnalysis::DataLayoutAnalysis(Operation *root)
+    : defaultLayout(std::make_unique<DataLayout>(DataLayoutOpInterface())) {
+  // Store a layout for `op` if it can specify one; ModuleOp is checked last.
+  auto computeLayout = [this](Operation *op) {
+    if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
+      layouts[op] = std::make_unique<DataLayout>(iface);
+    if (auto module = dyn_cast<ModuleOp>(op))
+      layouts[op] = std::make_unique<DataLayout>(module);
+  };
+
+  // Cover ops nested in (and including) `root`, then all of its ancestors.
+  root->walk(computeLayout);
+  for (Operation *ancestor = root->getParentOp(); ancestor != nullptr;
+       ancestor = ancestor->getParentOp()) {
+    computeLayout(ancestor);
+  }
+}
+
+const DataLayout &DataLayoutAnalysis::getAbove(Operation *operation) const {
+  // Climb the parent chain; the nearest ancestor with a recorded layout wins.
+  Operation *current = operation->getParentOp();
+  while (current) {
+    auto found = layouts.find(current);
+    if (found != layouts.end())
+      return *found->second;
+    current = current->getParentOp();
+  }
+  return *defaultLayout; // No ancestor specifies a layout.
+}
+
+const DataLayout &DataLayoutAnalysis::getAtOrAbove(Operation *operation) const {
+  auto found = layouts.find(operation);
+  if (found == layouts.end())
+    return getAbove(operation); // Not specified here; defer to ancestors.
+  return *found->second;
+}

diff  --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
index 13cf5cb16c9f6..8741eb08c33f6 100644
--- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRAnalysis
   MLIRDataLayoutInterfaces
   MLIRLLVMIR
   MLIRMath

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 11d0cd6fdc766..3ee6b31d08f53 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "../PassDetail.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -101,14 +102,16 @@ LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
 }
 
 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
-LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
-    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx)) {}
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
+                                     const DataLayoutAnalysis *analysis)
+    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
 
 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
-                                     const LowerToLLVMOptions &options)
-    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
-      options(options) {
+                                     const LowerToLLVMOptions &options,
+                                     const DataLayoutAnalysis *analysis)
+    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
+      dataLayoutAnalysis(analysis) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
 
   // Register conversions for the builtin types.
@@ -342,6 +345,14 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
   return results;
 }
 
+unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
+                                                    const DataLayout &layout) {
+  uint64_t ptrBytes = llvm::divideCeil(
+      getPointerBitwidth(type.getMemorySpaceAsInt()), 8); // bytes per pointer
+  int64_t numIndexFields = 1 + 2 * type.getRank(); // after the two pointers
+  return 2 * ptrBytes + numIndexFields * layout.getTypeSize(getIndexType());
+}
+
 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
@@ -369,6 +380,15 @@ SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
           LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))};
 }
 
+unsigned
+LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
+                                                   const DataLayout &layout) {
+  // Fields are {index, !llvm.ptr<i8>}; that pointer is always built in address
+  // space 0 (getUnrankedMemRefDescriptorFields), not in `type`'s memory space.
+  return layout.getTypeSize(getIndexType()) +
+         llvm::divideCeil(getPointerBitwidth(/*addressSpace=*/0), 8);
+}
+
 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
   if (!convertType(type.getElementType()))
     return {};
@@ -1900,26 +1920,30 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                 converter) {}
 
-  /// Returns the memref's element size in bytes.
+  /// Returns the memref's element size in bytes using the data layout active at
+  /// `op`.
   // TODO: there are other places where this is used. Expose publicly?
-  static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
-    auto elementType = memRefType.getElementType();
-
-    unsigned sizeInBits;
-    if (elementType.isIntOrFloat()) {
-      sizeInBits = elementType.getIntOrFloatBitWidth();
-    } else {
-      auto vectorType = elementType.cast<VectorType>();
-      sizeInBits =
-          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
+    const DataLayout *layout = &defaultLayout;
+    if (const DataLayoutAnalysis *analysis =
+            getTypeConverter()->getDataLayoutAnalysis()) {
+      layout = &analysis->getAbove(op);
     }
-    return llvm::divideCeil(sizeInBits, 8);
+    Type elementType = memRefType.getElementType();
+    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
+      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
+                                                         *layout);
+    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
+      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
+          memRefElementType, *layout);
+    return layout->getTypeSize(elementType);
   }
 
   /// Returns true if the memref size in bytes is known to be a multiple of
-  /// factor.
-  static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
-    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
+  /// factor assuming the data layout active at `op`.
+  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
+                              Operation *op) const {
+    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
       if (type.isDynamic(type.getDimSize(i)))
         continue;
@@ -1938,7 +1962,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
     // Whenever we don't have alignment set, we will use an alignment
     // consistent with the element type; since the allocation size has to be a
     // power of two, we will bump to the next power of two if it already isn't.
-    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
+    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
     return std::max(kMinAlignedAllocAlignment,
                     llvm::PowerOf2Ceil(eltSizeBytes));
   }
@@ -1954,7 +1978,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
 
     // aligned_alloc requires size to be a multiple of alignment; we will pad
     // the size to the next multiple if necessary.
-    if (!isMemRefSizeMultipleOf(memRefType, alignment))
+    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
 
     Type elementPtrType = this->getElementPtrType(memRefType);
@@ -1971,6 +1995,9 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
 
   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
+
+  /// Default layout to use in absence of the corresponding analysis.
+  DataLayout defaultLayout;
 };
 
 // Out of line definition, required till C++17.
@@ -4068,8 +4095,10 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
     }
 
     ModuleOp m = getOperation();
+    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
 
-    LowerToLLVMOptions options(&getContext(), DataLayout(m));
+    LowerToLLVMOptions options(&getContext(),
+                               dataLayoutAnalysis.getAtOrAbove(m));
     options.useBarePtrCallConv = useBarePtrCallConv;
     options.emitCWrappers = emitCWrappers;
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
@@ -4078,7 +4107,9 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                          : LowerToLLVMOptions::AllocLowering::Malloc);
     options.dataLayout = llvm::DataLayout(this->dataLayout);
-    LLVMTypeConverter typeConverter(&getContext(), options);
+
+    LLVMTypeConverter typeConverter(&getContext(), options,
+                                    &dataLayoutAnalysis);
 
     RewritePatternSet patterns(&getContext());
     populateStdToLLVMConversionPatterns(typeConverter, patterns);
@@ -4102,10 +4133,12 @@ Value AllocLikeOpLLVMLowering::createAligned(
   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
 }
 
-void AllocLikeOpLLVMLowering::rewrite(
+LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   MemRefType memRefType = getMemRefResultType(op);
+  if (!isConvertibleAndHasIdentityMaps(memRefType))
+    return rewriter.notifyMatchFailure(op, "incompatible memref type");
   auto loc = op->getLoc();
 
   // Get actual sizes of the memref as values: static sizes are constant
@@ -4129,6 +4162,7 @@ void AllocLikeOpLLVMLowering::rewrite(
 
   // Return the final value of the descriptor.
   rewriter.replaceOp(op, {memRefDescriptor});
+  return success();
 }
 
 mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
@@ -4159,6 +4193,6 @@ mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx)
     : LowerToLLVMOptions(ctx, DataLayout()) {}
 
 mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx,
-                                             mlir::DataLayout dl) {
+                                             const DataLayout &dl) {
   indexBitwidth = dl.getTypeSizeInBits(IndexType::get(ctx));
 }

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index ac189f467748c..e66810c3add57 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
-// RUN: mlir-opt -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
+// RUN: mlir-opt -split-input-file -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
 
 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-2: !llvm.ptr<f32>
@@ -529,3 +529,98 @@ func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
 
 // CHECK: ^bb3:
 // CHECK:   llvm.return
+
+// -----
+
+// ALIGNED-ALLOC-LABEL: @memref_of_memref
+func @memref_of_memref() {
+  // Sizeof computation is as usual.
+  // ALIGNED-ALLOC: %[[NULL:.*]] = llvm.mlir.null
+  // ALIGNED-ALLOC: %[[PTR:.*]] = llvm.getelementptr
+  // ALIGNED-ALLOC: %[[SIZEOF:.*]] = llvm.ptrtoint
+
+  // Static alignment should be computed as ceilPowerOf2(2 * sizeof(pointer) +
+  // (1 + 2 * rank) * sizeof(index)) = ceilPowerOf2(2 * 8 + 3 * 8) = 64.
+  // ALIGNED-ALLOC: llvm.mlir.constant(64 : index)
+
+  // Check that the types are converted as expected.
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  // ALIGNED-ALLOC: llvm.bitcast %{{.*}} : !llvm.ptr<i8> to
+  // ALIGNED-ALLOC-SAME: !llvm.
+  // ALIGNED-ALLOC-SAME: [[INNER:ptr<struct<\(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>\)>>]]
+  // ALIGNED-ALLOC: llvm.mlir.undef
+  // ALIGNED-ALLOC-SAME: !llvm.struct<([[INNER]], [[INNER]], i64, array<1 x i64>, array<1 x i64>)>
+  %0 = memref.alloc() : memref<1xmemref<1xf32>>
+  return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
+  // ALIGNED-ALLOC-LABEL: @memref_of_memref_32
+  func @memref_of_memref_32() {
+    // Sizeof computation is as usual.
+    // ALIGNED-ALLOC: %[[NULL:.*]] = llvm.mlir.null
+    // ALIGNED-ALLOC: %[[PTR:.*]] = llvm.getelementptr
+    // ALIGNED-ALLOC: %[[SIZEOF:.*]] = llvm.ptrtoint
+
+    // Static alignment should be computed as ceilPowerOf2(2 * sizeof(pointer) +
+    // (1 + 2 * rank) * sizeof(index)) = ceilPowerOf2(2 * 8 + 3 * 4) = 32.
+    // ALIGNED-ALLOC: llvm.mlir.constant(32 : index)
+
+    // Check that the types are converted as expected.
+    // ALIGNED-ALLOC: llvm.call @aligned_alloc
+    // ALIGNED-ALLOC: llvm.bitcast %{{.*}} : !llvm.ptr<i8> to
+    // ALIGNED-ALLOC-SAME: !llvm.
+    // ALIGNED-ALLOC-SAME: [[INNER:ptr<struct<\(ptr<f32>, ptr<f32>, i32, array<1 x i32>, array<1 x i32>\)>>]]
+    // ALIGNED-ALLOC: llvm.mlir.undef
+    // ALIGNED-ALLOC-SAME: !llvm.struct<([[INNER]], [[INNER]], i32, array<1 x i32>, array<1 x i32>)>
+    %0 = memref.alloc() : memref<1xmemref<1xf32>>
+    return
+  }
+}
+
+
+// -----
+
+// ALIGNED-ALLOC-LABEL: @memref_of_memref_of_memref
+func @memref_of_memref_of_memref() {
+  // Sizeof computation is as usual, also check the type.
+  // ALIGNED-ALLOC: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr<
+  // ALIGNED-ALLOC-SAME:   struct<(
+  // ALIGNED-ALLOC-SAME:     [[INNER:ptr<struct<\(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>\)>>]],
+  // ALIGNED-ALLOC-SAME:     [[INNER]],
+  // ALIGNED-ALLOC-SAME:     i64, array<1 x i64>, array<1 x i64>
+  // ALIGNED-ALLOC-SAME:   )>
+  // ALIGNED-ALLOC-SAME: >
+  // ALIGNED-ALLOC: %[[PTR:.*]] = llvm.getelementptr
+  // ALIGNED-ALLOC: %[[SIZEOF:.*]] = llvm.ptrtoint
+
+  // Static alignment should be computed as ceilPowerOf2(2 * sizeof(pointer) +
+  // (1 + 2 * rank) * sizeof(index)) = ceilPowerOf2(2 * 8 + 3 * 8) = 64.
+  // ALIGNED-ALLOC: llvm.mlir.constant(64 : index)
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %0 = memref.alloc() : memref<1 x memref<2 x memref<3 x f32>>>
+  return
+}
+
+// -----
+
+// ALIGNED-ALLOC-LABEL: @ranked_unranked
+func @ranked_unranked() {
+  // ALIGNED-ALLOC: llvm.mlir.null
+  // ALIGNED-ALLOC-SAME: !llvm.[[INNER:ptr<struct<\(i64, ptr<i8>\)>>]]
+  // ALIGNED-ALLOC: llvm.getelementptr
+  // ALIGNED-ALLOC: llvm.ptrtoint
+
+  // Static alignment should be computed as ceilPowerOf2(sizeof(index) +
+  // sizeof(pointer)) = ceilPowerOf2(8 + 8) = 16.
+  // ALIGNED-ALLOC: llvm.mlir.constant(16 : index)
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  // ALIGNED-ALLOC: llvm.bitcast
+  // ALIGNED-ALLOC-SAME: !llvm.ptr<i8> to !llvm.[[INNER]]
+  %0 = memref.alloc() : memref<1 x memref<* x f32>>
+  memref.cast %0 : memref<1 x memref<* x f32>> to memref<* x memref<* x f32>>
+  return
+}
+

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index 6df3c94943759..aeb222630afba 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -447,3 +447,4 @@ func private @unsupported_memref_element_type() -> memref<42 x !test.memref_elem
 // BAREPTR-SAME: memref<
 // BAREPTR-NOT: !llvm.ptr
 func private @unsupported_unranked_memref_element_type() -> memref<* x !test.memref_element>
+

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 2a3487cffe4c4..373fde0e89c93 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -181,6 +181,18 @@ func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
 // CHECK: func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
 func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
 
+// CHECK: func private @memref_of_memref(memref<1xmemref<1xf64>>)
+func private @memref_of_memref(memref<1xmemref<1xf64>>)
+
+// CHECK: func private @memref_of_unranked_memref(memref<1xmemref<*xf32>>)
+func private @memref_of_unranked_memref(memref<1xmemref<*xf32>>)
+
+// CHECK: func private @unranked_memref_of_memref(memref<*xmemref<1xf32>>)
+func private @unranked_memref_of_memref(memref<*xmemref<1xf32>>)
+
+// CHECK: func private @unranked_memref_of_unranked_memref(memref<*xmemref<*xi32>>)
+func private @unranked_memref_of_unranked_memref(memref<*xmemref<*xi32>>)
+
 // CHECK: func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 

diff  --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
index 843714c640326..6f31e957660aa 100644
--- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Pass/Pass.h"
@@ -23,28 +24,14 @@ struct TestDataLayoutQuery
   void runOnFunction() override {
     FuncOp func = getFunction();
     Builder builder(func.getContext());
-    DenseMap<Operation *, DataLayout> layouts;
+    const DataLayoutAnalysis &layouts = getAnalysis<DataLayoutAnalysis>();
 
     func.walk([&](test::DataLayoutQueryOp op) {
       // Skip the ops with already processed in a deeper call.
       if (op->getAttr("size"))
         return;
 
-      auto scope = op->getParentOfType<test::OpWithDataLayoutOp>();
-      if (!layouts.count(scope)) {
-        layouts.try_emplace(
-            scope, scope ? cast<DataLayoutOpInterface>(scope.getOperation())
-                         : nullptr);
-      }
-      auto module = op->getParentOfType<ModuleOp>();
-      if (!layouts.count(module))
-        layouts.try_emplace(module, module);
-
-      Operation *closest = (scope && module && module->isProperAncestor(scope))
-                               ? scope.getOperation()
-                               : module.getOperation();
-
-      const DataLayout &layout = layouts.find(closest)->getSecond();
+      const DataLayout &layout = layouts.getAbove(op);
       unsigned size = layout.getTypeSize(op.getType());
       unsigned bitsize = layout.getTypeSizeInBits(op.getType());
       unsigned alignment = layout.getTypeABIAlignment(op.getType());


        


More information about the Mlir-commits mailing list