[Mlir-commits] [mlir] 1916b0e - [mlir] support data layout specs on ModuleOp

Alex Zinenko llvmlistbot at llvm.org
Wed Mar 24 07:13:49 PDT 2021


Author: Alex Zinenko
Date: 2021-03-24T15:13:38+01:00
New Revision: 1916b0e098ad6ddeb746c4720099fb96bff02d31

URL: https://github.com/llvm/llvm-project/commit/1916b0e098ad6ddeb746c4720099fb96bff02d31
DIFF: https://github.com/llvm/llvm-project/commit/1916b0e098ad6ddeb746c4720099fb96bff02d31.diff

LOG: [mlir] support data layout specs on ModuleOp

ModuleOp is a natural place to provide scoped data layout information. However,
it is undesirable for ModuleOp to implement the entirety of
DataLayoutOpInterface because that would require either pushing the interface
inside the IR library instead of a separate library, or putting the default
implementation of the interface as inline functions in headers leading to
binary bloat. Instead, ModuleOp accepts an arbitrary data layout spec attribute
and has a dedicated hook to extract it, and DataLayout is modified to know
about ModuleOp particularities.

Reviewed By: herhut, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D98500

Added: 
    mlir/test/Interfaces/DataLayoutInterfaces/module.mlir

Modified: 
    mlir/docs/DataLayout.md
    mlir/include/mlir/IR/BuiltinOps.h
    mlir/include/mlir/IR/BuiltinOps.td
    mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
    mlir/lib/Dialect/DLTI/DLTI.cpp
    mlir/lib/IR/BuiltinDialect.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/lib/Interfaces/DataLayoutInterfaces.cpp
    mlir/test/Dialect/DLTI/invalid.mlir
    mlir/test/IR/module-op.mlir
    mlir/test/lib/Transforms/TestDataLayoutQuery.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DataLayout.md b/mlir/docs/DataLayout.md
index 4a57a2d6ca4c5..732dbe077285e 100644
--- a/mlir/docs/DataLayout.md
+++ b/mlir/docs/DataLayout.md
@@ -18,24 +18,26 @@ system. At the top level, it consists of:
     types.
 
 Built-in types are handled specially to decrease the overall query cost.
+Similarly, built-in `ModuleOp` supports data layouts without going through the
+interface.
 
 ## Usage
 
 ### Scoping
 
 Following MLIR's nested structure, data layout properties are _scoped_ to
-regions belonging to specific operations that implement the
-`DataLayoutOpInterface`. Such scoping operations partially control the data
-layout properties and may have attributes that affect them, typically organized
-in a data layout specification.
+regions belonging to either operations that implement the
+`DataLayoutOpInterface` or `ModuleOp` operations. Such scoping operations
+partially control the data layout properties and may have attributes that affect
+them, typically organized in a data layout specification.
 
 Types may have a different data layout in different scopes, including scopes
 that are nested in other scopes such as modules contained in other modules. At
 the same time, within the given scope excluding any nested scope, a given type
 has fixed data layout properties. Types are also expected to have a default,
 "natural" data layout in case they are used outside of any operation that
-provides data layout scope for them. This ensure data layout queries always have
-a valid result.
+provides data layout scope for them. This ensures that data layout queries
+always have a valid result.
 
 ### Compatibility and Transformations
 
@@ -180,20 +182,24 @@ and the compatibility of nested entries.
 
 The overall flow of a data layout property query is as follows.
 
--   The user constructs a `DataLayout` at the given scope. The constructor
+1.  The user constructs a `DataLayout` at the given scope. The constructor
     fetches the data layout specification and combines it with those of
     enclosing scopes (layouts are expected to be compatible).
--   The user calls `DataLayout::query(Type ty)`.
--   If `DataLayout` has a cached response, this response is returned
+2.  The user calls `DataLayout::query(Type ty)`.
+3.  If `DataLayout` has a cached response, this response is returned
     immediately.
--   Otherwise, the query is handed down by `DataLayout` to
-    `DataLayoutOpInterface::query(ty, *this, relevantEntries)` where the
-    relevant entries are computed as described above.
--   Unless the `query` hook is reimplemented by the op interface, the query is
+4.  Otherwise, the query is handed down by `DataLayout` to the closest layout
+    scoping operation. If it implements `DataLayoutOpInterface`, then the query
+    is forwarded to `DataLayoutOpInterface::query(ty, *this, relevantEntries)`
+    where the relevant entries are computed as described above. If it does not
+    implement `DataLayoutOpInterface`, it must be a `ModuleOp`, and the query is
+    forwarded to `DataLayoutTypeInterface::query(dataLayout, relevantEntries)`
+    after casting `ty` to the type interface.
+5.  Unless the `query` hook is reimplemented by the op interface, the query is
     handled further down to `DataLayoutTypeInterface::query(dataLayout,
     relevantEntries)` after casting `ty` to the type interface. If the type does
     not implement the interface, an unrecoverable fatal error is produced.
--   The type is expected to always provide the response, which is returned up
+6.  The type is expected to always provide the response, which is returned up
     the call stack and cached by the `DataLayout.`
 
 ## Default Implementation
@@ -201,6 +207,14 @@ The overall flow of a data layout property query is as follows.
 The default implementation of the data layout interfaces directly handles
 queries for a subset of built-in types.
 
+### Built-in Modules
+
+Built-in `ModuleOp` allows at most one attribute that implements
+`DataLayoutSpecInterface`. It does not implement the entire interface for
+efficiency and layering reasons. Instead, `DataLayout` can be constructed for
+`ModuleOp` and handles modules transparently alongside other operations that
+implement the interface.
+
 ### Built-in Types
 
 The following describes the default properties of built-in types.
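
The query flow in the updated documentation can be modeled without MLIR: construct a layout object at a scope, consult the cache, hand the query to the scoping op, fall back to the type's default handler, then cache the answer. The C++ below is a minimal sketch under stated assumptions, not MLIR's implementation. `Scope`, `SpecEntries`, `defaultTypeSize`, `sizeHook`, and `LayoutQueryModel` are hypothetical names. A scope with a `sizeHook` plays the role of an op that reimplements the `DataLayoutOpInterface` hook; a scope without one behaves like a `ModuleOp`, whose queries always go to the type-level default.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical: layout entries relevant to one type, as (key, value) pairs.
using SpecEntries = std::vector<std::pair<std::string, unsigned>>;

// Hypothetical scoping op. `entriesByType` is its already-combined spec.
// A non-empty `sizeHook` models an op that reimplements the interface hook;
// an empty one models a ModuleOp, which never overrides queries.
struct Scope {
  std::map<std::string, SpecEntries> entriesByType;
  std::function<unsigned(const std::string &, const SpecEntries &)> sizeHook;
};

// Hypothetical type-level default: honor a "size" entry, else use a
// made-up natural size.
unsigned defaultTypeSize(const std::string &type, const SpecEntries &entries) {
  for (const auto &entry : entries)
    if (entry.first == "size")
      return entry.second;
  return type == "i64" ? 8u : 4u;
}

class LayoutQueryModel {
public:
  // Step 1: construct the layout object at a given scope.
  explicit LayoutQueryModel(const Scope &scope) : scope(scope) {}

  // Step 2: the user asks for the size of a type.
  unsigned getTypeSize(const std::string &type) {
    // Step 3: answer from the cache when possible.
    auto cached = cache.find(type);
    if (cached != cache.end())
      return cached->second;

    // Step 4: select the entries relevant to this type.
    SpecEntries relevant;
    auto found = scope.entriesByType.find(type);
    if (found != scope.entriesByType.end())
      relevant = found->second;

    // Steps 4-6: an op-level hook takes precedence; otherwise the type's
    // default handler answers. The result is cached either way.
    unsigned result = scope.sizeHook ? scope.sizeHook(type, relevant)
                                     : defaultTypeSize(type, relevant);
    ++uncachedLookups;
    cache.emplace(type, result);
    return result;
  }

  // Counts lookups that missed the cache (for illustration only).
  unsigned uncachedLookups = 0;

private:
  const Scope &scope;
  std::map<std::string, unsigned> cache;
};
```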

diff  --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h
index c0163b148f3cc..cf43b7cd7305d 100644
--- a/mlir/include/mlir/IR/BuiltinOps.h
+++ b/mlir/include/mlir/IR/BuiltinOps.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 

diff  --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index 3e61608ace1d5..4d14b8868e473 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -18,6 +18,7 @@ include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // Base class for Builtin dialect ops.
@@ -198,6 +199,12 @@ def ModuleOp : Builtin_Op<"module", [
 
     /// A ModuleOp may optionally define a symbol.
     bool isOptionalSymbol() { return true; }
+
+    //===------------------------------------------------------------------===//
+    // DataLayoutOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    DataLayoutSpecInterface getDataLayoutSpec();
   }];
   let verifier = [{ return ::verify(*this); }];
 

diff  --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index f92048973e714..99fc718b17333 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -29,6 +29,7 @@ using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
 using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
 class DataLayoutOpInterface;
 class DataLayoutSpecInterface;
+class ModuleOp;
 
 namespace detail {
 /// Default handler for the type size request. Computes results for built-in
@@ -60,10 +61,11 @@ DataLayoutEntryList filterEntriesForType(DataLayoutEntryListRef entries,
 DataLayoutEntryInterface
 filterEntryForIdentifier(DataLayoutEntryListRef entries, Identifier id);
 
-/// Verifies that the operation implementing the data layout interface is valid.
-/// This calls the verifier of the spec attribute and checks if the layout is
-/// compatible with specs attached to the enclosing operations.
-LogicalResult verifyDataLayoutOp(DataLayoutOpInterface op);
+/// Verifies that the operation implementing the data layout interface, or a
+/// module operation, is valid. This calls the verifier of the spec attribute
+/// and checks if the layout is compatible with specs attached to the enclosing
+/// operations.
+LogicalResult verifyDataLayoutOp(Operation *op);
 
 /// Verifies that a data layout spec is valid. This dispatches to individual
 /// entry verifiers, and then to the verifiers implemented by the relevant type
@@ -133,6 +135,7 @@ class DataLayoutDialectInterface
 class DataLayout {
 public:
   explicit DataLayout(DataLayoutOpInterface op);
+  explicit DataLayout(ModuleOp op);
 
   /// Returns the size of the given type in the current scope.
   unsigned getTypeSize(Type t) const;
@@ -159,7 +162,7 @@ class DataLayout {
   /// Operation defining the scope of requests.
   // TODO: this is mutable because the generated interface method are not const.
   // Update the generator to support const methods and change this to const.
-  mutable DataLayoutOpInterface scope;
+  mutable Operation *scope;
 
   /// Caches for individual requests.
   mutable DenseMap<Type, unsigned> sizes;
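
Because `DataLayout` caches answers per type and now stores its scope as a plain `Operation *`, cached answers become stale if the layouts of enclosing ops change after construction. In the implementation diff, debug builds record the enclosing specs in `layoutStack` and `checkValid` asserts that they still match. The sketch below illustrates that idea with hypothetical `ScopeNode`, `collectParentSpecs`, and `LayoutSnapshot` names; it returns a `bool` instead of asserting so it can be exercised, and it omits MLIR's rule of skipping a top-level module that has no spec.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical: a spec attribute identified by its printed form, or none.
using SpecId = std::optional<std::string>;

// Hypothetical scoping op whose spec may be changed after a layout is built.
struct ScopeNode {
  SpecId spec;
  const ScopeNode *parent = nullptr;
};

// Specs of all proper ancestors, innermost first. Unlike MLIR, this does not
// skip a top-level module that has no spec.
std::vector<SpecId> collectParentSpecs(const ScopeNode &leaf) {
  std::vector<SpecId> specs;
  for (const ScopeNode *p = leaf.parent; p != nullptr; p = p->parent)
    specs.push_back(p->spec);
  return specs;
}

// Snapshot taken at construction; isValid() re-collects the enclosing specs
// and reports whether answers cached against the snapshot can still be used.
class LayoutSnapshot {
public:
  explicit LayoutSnapshot(const ScopeNode &scope)
      : scope(scope), layoutStack(collectParentSpecs(scope)) {}

  bool isValid() const {
    std::vector<SpecId> current = collectParentSpecs(scope);
    // Both the number of enclosing layouts and each spec must be unchanged.
    return current == layoutStack;
  }

private:
  const ScopeNode &scope;
  std::vector<SpecId> layoutStack;
};
```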

diff  --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 4b9c3b523b6bf..2567be64ac1ad 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -369,6 +370,8 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
       return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName
                              << "' is expected to be a #dlti.dl_spec attribute";
     }
+    if (isa<ModuleOp>(op))
+      return detail::verifyDataLayoutOp(op);
     return success();
   }
 

diff  --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 28aef1500a00c..1035961f51c1c 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -222,6 +222,17 @@ ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
   return builder.create<ModuleOp>(loc, name);
 }
 
+DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
+  // Take the first and only (if present) attribute that implements the
+  // interface. This needs a linear search, but is called only once per data
+  // layout object construction that is used for repeated queries.
+  for (Attribute attr : llvm::make_second_range(getOperation()->getAttrs())) {
+    if (auto spec = attr.dyn_cast<DataLayoutSpecInterface>())
+      return spec;
+  }
+  return {};
+}
+
 static LogicalResult verify(ModuleOp op) {
   // Check that none of the attributes are non-dialect attributes, except for
   // the symbol related attributes.
@@ -236,6 +247,23 @@ static LogicalResult verify(ModuleOp op) {
                               << attr.first << "'";
   }
 
+  // Check that there is at most one data layout spec attribute.
+  StringRef layoutSpecAttrName;
+  DataLayoutSpecInterface layoutSpec;
+  for (const NamedAttribute &na : op->getAttrs()) {
+    if (auto spec = na.second.dyn_cast<DataLayoutSpecInterface>()) {
+      if (layoutSpec) {
+        InFlightDiagnostic diag =
+            op.emitOpError() << "expects at most one data layout attribute";
+        diag.attachNote() << "'" << layoutSpecAttrName
+                          << "' is a data layout attribute";
+        diag.attachNote() << "'" << na.first << "' is a data layout attribute";
+      }
+      layoutSpecAttrName = na.first.strref();
+      layoutSpec = spec;
+    }
+  }
+
   return success();
 }
 

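`ModuleOp::getDataLayoutSpec` takes the first attribute whose value implements `DataLayoutSpecInterface`, and the module verifier rejects a second one, naming both attributes. The sketch below models both checks over a plain attribute list. `NamedAttr`, its `isLayoutSpec` flag (standing in for the `dyn_cast<DataLayoutSpecInterface>` test), `findLayoutSpec`, and `verifyAtMostOneSpec` are hypothetical, and the verifier returns messages rather than emitting MLIR diagnostics.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical named attribute; `isLayoutSpec` stands in for
// `attr.dyn_cast<DataLayoutSpecInterface>()` succeeding.
struct NamedAttr {
  std::string name;
  bool isLayoutSpec;
};

// Like ModuleOp::getDataLayoutSpec: linear search for the first spec.
std::optional<std::string> findLayoutSpec(const std::vector<NamedAttr> &attrs) {
  for (const NamedAttr &attr : attrs)
    if (attr.isLayoutSpec)
      return attr.name;
  return std::nullopt;
}

// Like the module verifier: every spec after the first produces an error
// that names the previously seen spec and the current one.
std::vector<std::string>
verifyAtMostOneSpec(const std::vector<NamedAttr> &attrs) {
  std::vector<std::string> errors;
  std::optional<std::string> previous;
  for (const NamedAttr &attr : attrs) {
    if (!attr.isLayoutSpec)
      continue;
    if (previous)
      errors.push_back("expects at most one data layout attribute: '" +
                       *previous + "' and '" + attr.name + "'");
    previous = attr.name;
  }
  return errors;
}
```
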
diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index cef7068af4d8b..68367d69b68a8 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRIR
   MLIRBuiltinTypesIncGen
   MLIRCallInterfacesIncGen
   MLIRCastInterfacesIncGen
+  MLIRDataLayoutInterfacesIncGen
   MLIROpAsmInterfaceIncGen
   MLIRRegionKindInterfaceIncGen
   MLIRSideEffectInterfacesIncGen

diff  --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 3d23aa8859c8d..4c5f45eefeadd 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -8,9 +8,12 @@
 
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -105,65 +108,97 @@ mlir::detail::filterEntryForIdentifier(DataLayoutEntryListRef entries,
   return it == entries.end() ? DataLayoutEntryInterface() : *it;
 }
 
+static DataLayoutSpecInterface getSpec(Operation *operation) {
+  return llvm::TypeSwitch<Operation *, DataLayoutSpecInterface>(operation)
+      .Case<ModuleOp, DataLayoutOpInterface>(
+          [&](auto op) { return op.getDataLayoutSpec(); })
+      .Default([](Operation *) {
+        llvm_unreachable("expected an op with data layout spec");
+        return DataLayoutSpecInterface();
+      });
+}
+
 /// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
-/// implement the `DataLayoutOpInterface`.
-static void findProperAscendantsWithLayout(
-    Operation *leaf, SmallVectorImpl<DataLayoutOpInterface> &opsWithLayout) {
+/// are either modules or implement the `DataLayoutOpInterface`.
+static void
+collectParentLayouts(Operation *leaf,
+                     SmallVectorImpl<DataLayoutSpecInterface> &specs,
+                     SmallVectorImpl<Location> *opLocations = nullptr) {
   if (!leaf)
     return;
 
-  while (auto opLayout = leaf->getParentOfType<DataLayoutOpInterface>()) {
-    opsWithLayout.push_back(opLayout);
-    leaf = opLayout;
+  for (Operation *parent = leaf->getParentOp(); parent != nullptr;
+       parent = parent->getParentOp()) {
+    llvm::TypeSwitch<Operation *>(parent)
+        .Case<ModuleOp>([&](ModuleOp op) {
+          // Skip top-level module op unless it has a layout. Top-level module
+          // without layout is most likely the one implicitly added by the
+          // parser and it doesn't have location. Top-level null specification
+          // would have had the same effect as not having a specification at all
+          // (using type defaults).
+          if (!op->getParentOp() && !op.getDataLayoutSpec())
+            return;
+          specs.push_back(op.getDataLayoutSpec());
+          if (opLocations)
+            opLocations->push_back(op.getLoc());
+        })
+        .Case<DataLayoutOpInterface>([&](DataLayoutOpInterface op) {
+          specs.push_back(op.getDataLayoutSpec());
+          if (opLocations)
+            opLocations->push_back(op.getLoc());
+        });
   }
 }
 
 /// Returns a layout spec that is a combination of the layout specs attached
 /// to the given operation and all its ancestors.
-static DataLayoutSpecInterface
-getCombinedDataLayout(DataLayoutOpInterface leaf) {
+static DataLayoutSpecInterface getCombinedDataLayout(Operation *leaf) {
   if (!leaf)
     return {};
 
+  assert((isa<ModuleOp, DataLayoutOpInterface>(leaf)) &&
+         "expected an op with data layout spec");
+
   SmallVector<DataLayoutOpInterface> opsWithLayout;
-  findProperAscendantsWithLayout(leaf, opsWithLayout);
+  SmallVector<DataLayoutSpecInterface> specs;
+  collectParentLayouts(leaf, specs);
 
   // Fast track if there are no ancestors.
-  if (opsWithLayout.empty())
-    return leaf.getDataLayoutSpec();
+  if (specs.empty())
+    return getSpec(leaf);
 
   // Create the list of non-null specs (null/missing specs can be safely
   // ignored) from the outermost to the innermost.
-  SmallVector<DataLayoutSpecInterface> specs;
-  specs.reserve(opsWithLayout.size());
-  for (DataLayoutOpInterface op : llvm::reverse(opsWithLayout))
-    if (DataLayoutSpecInterface current = op.getDataLayoutSpec())
-      specs.push_back(current);
+  auto nonNullSpecs = llvm::to_vector<2>(llvm::make_filter_range(
+      llvm::reverse(specs),
+      [](DataLayoutSpecInterface iface) { return iface != nullptr; }));
 
   // Combine the specs using the innermost as anchor.
-  if (DataLayoutSpecInterface current = leaf.getDataLayoutSpec())
-    return current.combineWith(specs);
-  if (specs.empty())
+  if (DataLayoutSpecInterface current = getSpec(leaf))
+    return current.combineWith(nonNullSpecs);
+  if (nonNullSpecs.empty())
     return {};
-  return specs.back().combineWith(llvm::makeArrayRef(specs).drop_back());
+  return nonNullSpecs.back().combineWith(
+      llvm::makeArrayRef(nonNullSpecs).drop_back());
 }
 
-LogicalResult mlir::detail::verifyDataLayoutOp(DataLayoutOpInterface op) {
-  DataLayoutSpecInterface spec = op.getDataLayoutSpec();
+LogicalResult mlir::detail::verifyDataLayoutOp(Operation *op) {
+  DataLayoutSpecInterface spec = getSpec(op);
   // The layout specification may be missing and it's fine.
   if (!spec)
     return success();
 
-  if (failed(spec.verifySpec(op.getLoc())))
+  if (failed(spec.verifySpec(op->getLoc())))
     return failure();
   if (!getCombinedDataLayout(op)) {
     InFlightDiagnostic diag =
-        op.emitError()
-        << "data layout is not a refinement of the layouts in enclosing ops";
-    SmallVector<DataLayoutOpInterface> opsWithLayout;
-    findProperAscendantsWithLayout(op, opsWithLayout);
-    for (DataLayoutOpInterface parent : opsWithLayout)
-      diag.attachNote(parent.getLoc()) << "enclosing op with data layout";
+        op->emitError()
+        << "data layout does not combine with layouts of enclosing ops";
+    SmallVector<DataLayoutSpecInterface> specs;
+    SmallVector<Location> opLocations;
+    collectParentLayouts(op, specs, &opLocations);
+    for (Location loc : opLocations)
+      diag.attachNote(loc) << "enclosing op with data layout";
     return diag;
   }
   return success();
@@ -173,33 +208,40 @@ LogicalResult mlir::detail::verifyDataLayoutOp(DataLayoutOpInterface op) {
 // DataLayout
 //===----------------------------------------------------------------------===//
 
-mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
-    : originalLayout(getCombinedDataLayout(op)), scope(op) {
+template <typename OpTy>
+void checkMissingLayout(DataLayoutSpecInterface originalLayout, OpTy op) {
   if (!originalLayout) {
     assert((!op || !op.getDataLayoutSpec()) &&
            "could not compute layout information for an op (failed to "
            "combine attributes?)");
   }
+}
 
+mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
+    : originalLayout(getCombinedDataLayout(op)), scope(op) {
 #ifndef NDEBUG
-  SmallVector<DataLayoutOpInterface> opsWithLayout;
-  findProperAscendantsWithLayout(op, opsWithLayout);
-  layoutStack = llvm::to_vector<2>(
-      llvm::map_range(opsWithLayout, [](DataLayoutOpInterface iface) {
-        return iface.getDataLayoutSpec();
-      }));
+  checkMissingLayout(originalLayout, op);
+  collectParentLayouts(op, layoutStack);
+#endif
+}
+
+mlir::DataLayout::DataLayout(ModuleOp op)
+    : originalLayout(getCombinedDataLayout(op)), scope(op) {
+#ifndef NDEBUG
+  checkMissingLayout(originalLayout, op);
+  collectParentLayouts(op, layoutStack);
 #endif
 }
 
 void mlir::DataLayout::checkValid() const {
 #ifndef NDEBUG
-  SmallVector<DataLayoutOpInterface> opsWithLayout;
-  findProperAscendantsWithLayout(scope, opsWithLayout);
-  assert(opsWithLayout.size() == layoutStack.size() &&
+  SmallVector<DataLayoutSpecInterface> specs;
+  collectParentLayouts(scope, specs);
+  assert(specs.size() == layoutStack.size() &&
          "data layout object used, but no longer valid due to the change in "
          "number of nested layouts");
-  for (auto pair : llvm::zip(opsWithLayout, layoutStack)) {
-    Attribute newLayout = std::get<0>(pair).getDataLayoutSpec();
+  for (auto pair : llvm::zip(specs, layoutStack)) {
+    Attribute newLayout = std::get<0>(pair);
     Attribute origLayout = std::get<1>(pair);
     assert(newLayout == origLayout &&
            "data layout object used, but no longer valid "
@@ -228,30 +270,39 @@ static unsigned cachedLookup(Type t, DenseMap<Type, unsigned> &cache,
 unsigned mlir::DataLayout::getTypeSize(Type t) const {
   checkValid();
   return cachedLookup(t, sizes, [&](Type ty) {
-    return (scope && originalLayout)
-               ? scope.getTypeSize(
-                     ty, *this, originalLayout.getSpecForType(ty.getTypeID()))
-               : detail::getDefaultTypeSize(ty, *this, {});
+    if (originalLayout) {
+      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
+      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
+        return iface.getTypeSize(ty, *this, list);
+      return detail::getDefaultTypeSize(ty, *this, list);
+    }
+    return detail::getDefaultTypeSize(ty, *this, {});
   });
 }
 
 unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const {
   checkValid();
   return cachedLookup(t, abiAlignments, [&](Type ty) {
-    return (scope && originalLayout)
-               ? scope.getTypeABIAlignment(
-                     ty, *this, originalLayout.getSpecForType(ty.getTypeID()))
-               : detail::getDefaultABIAlignment(ty, *this, {});
+    if (originalLayout) {
+      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
+      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
+        return iface.getTypeABIAlignment(ty, *this, list);
+      return detail::getDefaultABIAlignment(ty, *this, list);
+    }
+    return detail::getDefaultABIAlignment(ty, *this, {});
   });
 }
 
 unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const {
   checkValid();
   return cachedLookup(t, preferredAlignments, [&](Type ty) {
-    return (scope && originalLayout)
-               ? scope.getTypePreferredAlignment(
-                     ty, *this, originalLayout.getSpecForType(ty.getTypeID()))
-               : detail::getDefaultPreferredAlignment(ty, *this, {});
+    if (originalLayout) {
+      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
+      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
+        return iface.getTypePreferredAlignment(ty, *this, list);
+      return detail::getDefaultPreferredAlignment(ty, *this, list);
+    }
+    return detail::getDefaultPreferredAlignment(ty, *this, {});
   });
 }
 

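`collectParentLayouts` and `getCombinedDataLayout` walk the proper ancestors, keep only modules and `DataLayoutOpInterface` ops, skip a top-level module without a spec, drop null specs, and combine from outermost to innermost with the innermost spec as the anchor. The C++ sketch below mirrors that control flow. `Op`, `OpKind`, `Spec`, `combine`, and `getCombined` are hypothetical, and the combination rule (entries sharing a key must have equal values) is an assumption chosen to match the `unknown.unknown` 33 vs 32 test, not DLTI's actual combination logic.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical spec: identifier -> value.
using Spec = std::map<std::string, int>;

// Which kinds of ancestors contribute a (possibly null) spec.
enum class OpKind { Module, LayoutInterfaceOp, Other };

// Hypothetical operation with a parent link and an optional spec.
struct Op {
  OpKind kind;
  std::optional<Spec> spec;
  const Op *parent = nullptr;
};

// Like collectParentLayouts: specs of proper ancestors, innermost first.
// Only modules and interface ops count, and a top-level module without a
// spec (e.g. the one the parser adds implicitly) is skipped.
std::vector<std::optional<Spec>> collectParentSpecs(const Op &leaf) {
  std::vector<std::optional<Spec>> specs;
  for (const Op *p = leaf.parent; p != nullptr; p = p->parent) {
    if (p->kind == OpKind::Other)
      continue;
    if (p->kind == OpKind::Module && p->parent == nullptr && !p->spec)
      continue;
    specs.push_back(p->spec);
  }
  return specs;
}

// Assumed combination rule: entries sharing a key must have equal values.
std::optional<Spec> combine(Spec anchor, const std::vector<Spec> &outer) {
  for (const Spec &spec : outer) {
    for (const auto &[key, value] : spec) {
      auto [it, inserted] = anchor.emplace(key, value);
      if (!inserted && it->second != value)
        return std::nullopt; // mismatching entries don't combine
    }
  }
  return anchor;
}

// Like getCombinedDataLayout: nullopt means "no layout" or "failed to combine".
std::optional<Spec> getCombined(const Op &leaf) {
  std::vector<std::optional<Spec>> parents = collectParentSpecs(leaf);
  // Keep the non-null specs, ordered from outermost to innermost.
  std::vector<Spec> nonNull;
  for (auto it = parents.rbegin(); it != parents.rend(); ++it)
    if (*it)
      nonNull.push_back(**it);
  // The leaf's own spec, if any, is the anchor.
  if (leaf.spec)
    return combine(*leaf.spec, nonNull);
  if (nonNull.empty())
    return std::nullopt;
  // Otherwise the innermost ancestor spec anchors the rest.
  Spec innermost = nonNull.back();
  nonNull.pop_back();
  return combine(innermost, nonNull);
}
```
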
diff  --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 9f7ff7e36c376..aa9a713b26b84 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -55,7 +55,7 @@
 
 // Mismatching entries don't combine.
 "test.op_with_data_layout"() ({
-  // expected-error @below {{data layout is not a refinement of the layouts in enclosing ops}}
+  // expected-error @below {{data layout does not combine with layouts of enclosing ops}}
   // expected-note @above {{enclosing op with data layout}}
   "test.op_with_data_layout"() { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>> } : () -> ()
   "test.maybe_terminator_op"() : () -> ()
@@ -71,3 +71,22 @@
 
 // expected-error at below {{data layout specified for a type that does not support it}}
 "test.op_with_data_layout"() { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!test.test_type, 32>> } : () -> ()
+
+// -----
+
+// Mismatching entries are checked on module ops as well.
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 33>>} {
+  // expected-error @below {{data layout does not combine with layouts of enclosing ops}}
+  // expected-note @above {{enclosing op with data layout}}
+  module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>>} {
+  }
+}
+
+// -----
+
+// Mismatching entries are checked on a combination of modules and other ops.
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 33>>} {
+  // expected-error @below {{data layout does not combine with layouts of enclosing ops}}
+  // expected-note @above {{enclosing op with data layout}}
+  "test.op_with_data_layout"() { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>>} : () -> ()
+}

diff  --git a/mlir/test/IR/module-op.mlir b/mlir/test/IR/module-op.mlir
index 2e5bd6f9685e6..b610c0076ac2d 100644
--- a/mlir/test/IR/module-op.mlir
+++ b/mlir/test/IR/module-op.mlir
@@ -55,3 +55,12 @@ module @foo {
     }
   }
 }
+
+// -----
+
+// expected-error @below {{expects at most one data layout attribute}}
+// expected-note @below {{'test.another_attribute' is a data layout attribute}}
+// expected-note @below {{'test.random_attribute' is a data layout attribute}}
+module attributes { test.random_attribute = #dlti.dl_spec<>,
+                    test.another_attribute = #dlti.dl_spec<>} {
+}

diff  --git a/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir
new file mode 100644
index 0000000000000..b6e02c5e388ca
--- /dev/null
+++ b/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --test-data-layout-query %s | FileCheck %s
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+      #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 12]>,
+      #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 32]>>} {
+  // CHECK-LABEL: @module_level_layout
+  func @module_level_layout() {
+     // CHECK: alignment = 32
+     // CHECK: preferred = 1
+     // CHECK: size = 12
+    "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
+    return
+  }
+}

diff  --git a/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp b/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp
index c6823a92775f6..76fe79d445ee2 100644
--- a/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Transforms/TestDataLayoutQuery.cpp
@@ -36,8 +36,15 @@ struct TestDataLayoutQuery
             scope, scope ? cast<DataLayoutOpInterface>(scope.getOperation())
                          : nullptr);
       }
+      auto module = op->getParentOfType<ModuleOp>();
+      if (!layouts.count(module))
+        layouts.try_emplace(module, module);
 
-      const DataLayout &layout = layouts.find(scope)->getSecond();
+      Operation *closest = (scope && module && module->isProperAncestor(scope))
+                               ? scope.getOperation()
+                               : module.getOperation();
+
+      const DataLayout &layout = layouts.find(closest)->getSecond();
       unsigned size = layout.getTypeSize(op.getType());
       unsigned alignment = layout.getTypeABIAlignment(op.getType());
       unsigned preferred = layout.getTypePreferredAlignment(op.getType());
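
The test pass now builds a layout for the enclosing module as well as for the nearest `DataLayoutOpInterface` op, and queries whichever is closer: the interface op if the module properly encloses it, otherwise the module. A minimal sketch of that selection over a parent-linked tree follows; `Node`, `isProperAncestor`, and `closestScope` are hypothetical stand-ins for `Operation`, `Operation::isProperAncestor`, and the pass's inline logic.

```cpp
#include <cassert>
#include <string>

// Hypothetical operation in a parent-linked tree.
struct Node {
  std::string name;
  const Node *parent = nullptr;
};

// True if `ancestor` strictly encloses `op`, in the spirit of
// Operation::isProperAncestor.
bool isProperAncestor(const Node *ancestor, const Node *op) {
  for (const Node *p = op->parent; p != nullptr; p = p->parent)
    if (p == ancestor)
      return true;
  return false;
}

// Pick the closer of the nearest interface-op scope and the nearest module:
// the interface op wins only if the module properly encloses it.
const Node *closestScope(const Node *interfaceScope, const Node *nearestModule) {
  if (interfaceScope && nearestModule &&
      isProperAncestor(nearestModule, interfaceScope))
    return interfaceScope;
  return nearestModule;
}
```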
