[Mlir-commits] [mlir] b3386a7 - [mlir] introduce data layout entry for index type

Alex Zinenko llvmlistbot at llvm.org
Wed Mar 24 07:13:55 PDT 2021


Author: Alex Zinenko
Date: 2021-03-24T15:13:42+01:00
New Revision: b3386a734e430be967e85ab2fb980eeea927ade8

URL: https://github.com/llvm/llvm-project/commit/b3386a734e430be967e85ab2fb980eeea927ade8
DIFF: https://github.com/llvm/llvm-project/commit/b3386a734e430be967e85ab2fb980eeea927ade8.diff

LOG: [mlir] introduce data layout entry for index type

Index type is an integer type of target-specific bitwidth present in many MLIR
operations (loops, memory accesses). Converting values of this type to
fixed-size integers has always been problematic. Introduce a data layout entry
to specify the bitwidth of `index` in a given layout scope, defaulting to 64
bits, which is a commonly used assumption, e.g., in constants.

Port the builtin-to-LLVM type conversion to use this data layout entry when
converting the `index` type, and untie it from the pointer size. This is
particularly relevant for GPU targets. Keep the ability to forcibly override
the index type in lowerings.

Depends On D98525

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D98937

Added: 
    mlir/test/Interfaces/DataLayoutInterfaces/types.mlir

Modified: 
    mlir/docs/DataLayout.md
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
    mlir/include/mlir/Dialect/GPU/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
    mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Interfaces/DataLayoutInterfaces.cpp
    mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DataLayout.md b/mlir/docs/DataLayout.md
index 66222dfdff18..6d2aab864669 100644
--- a/mlir/docs/DataLayout.md
+++ b/mlir/docs/DataLayout.md
@@ -253,6 +253,24 @@ with the
 [modeling of n-D vectors](https://mlir.llvm.org/docs/Dialects/Vector/#deeperdive).
 They **may change** in the future.
 
+#### `index` type
+
+Index type is an integer type used for target-specific size information in,
+e.g., `memref` operations. Its data layout is parameterized by a single integer
+data layout entry that specifies its bitwidth. For example,
+
+```
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<index, 32>
+>} {}
+```
+
+specifies that `index` has 32 bits. All other layout properties of `index` match
+those of the integer type with the same bitwidth defined above.
+
+In the absence of the corresponding entry, `index` is assumed to be a 64-bit
+integer.
+
 ### Byte Size
 
 The default data layout assumes 8-bit bytes.

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 84052c6676e4..43a1afcc3f89 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -129,7 +129,7 @@ class LLVMTypeConverter : public TypeConverter {
   Type getIndexType();
 
   /// Gets the bitwidth of the index type when converted to LLVM.
-  unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
+  unsigned getIndexTypeBitwidth() { return options.getIndexBitwidth(); }
 
   /// Gets the pointer bitwidth.
   unsigned getPointerBitwidth(unsigned addressSpace = 0);

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 1d14fb9d0fd2..a7a68ef0a9c6 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -14,7 +14,9 @@
 #include <memory>
 
 namespace mlir {
+class DataLayout;
 class LLVMTypeConverter;
+class MLIRContext;
 class ModuleOp;
 template <typename T>
 class OperationPass;
@@ -27,10 +29,14 @@ static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
 
 /// Options to control the Standard dialect to LLVM lowering. The struct is used
 /// to share lowering options between passes, patterns, and type converter.
-struct LowerToLLVMOptions {
+class LowerToLLVMOptions {
+public:
+  explicit LowerToLLVMOptions(MLIRContext *ctx);
+  explicit LowerToLLVMOptions(MLIRContext *ctx, DataLayout dl);
+
   bool useBarePtrCallConv = false;
   bool emitCWrappers = false;
-  unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
+
   /// Use aligned_alloc for heap allocations.
   bool useAlignedAlloc = false;
 
@@ -39,11 +45,18 @@ struct LowerToLLVMOptions {
   // TODO: this should be replaced by MLIR data layout when one exists.
   llvm::DataLayout dataLayout = llvm::DataLayout("");
 
-  /// Get a statically allocated copy of the default LowerToLLVMOptions.
-  static const LowerToLLVMOptions &getDefaultOptions() {
-    static LowerToLLVMOptions options;
-    return options;
+  /// Set the index bitwidth to the given value.
+  void overrideIndexBitwidth(unsigned bitwidth) {
+    assert(bitwidth != kDeriveIndexBitwidthFromDataLayout &&
+           "can only override to a concrete bitwidth");
+    indexBitwidth = bitwidth;
   }
+
+  /// Get the index bitwidth.
+  unsigned getIndexBitwidth() const { return indexBitwidth; }
+
+private:
+  unsigned indexBitwidth;
 };
 
 /// Collect a set of patterns to convert memory-related operations from the
@@ -75,9 +88,9 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
 /// stdlib malloc/free is used by default for allocating memrefs allocated with
 /// memref.alloc, while LLVM's alloca is used for those allocated with
 /// memref.alloca.
+std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>>
-createLowerToLLVMPass(const LowerToLLVMOptions &options =
-                          LowerToLLVMOptions::getDefaultOptions());
+createLowerToLLVMPass(const LowerToLLVMOptions &options);
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 1e43ebeb55be..26ab17172714 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_GPU_GPUDIALECT_H
 #define MLIR_DIALECT_GPU_GPUDIALECT_H
 
+#include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 5fb0793030b0..41206af46ae2 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -13,9 +13,11 @@
 #ifndef GPU_OPS
 #define GPU_OPS
 
+include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/GPU/GPUBase.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -694,7 +696,8 @@ def GPU_BarrierOp : GPU_Op<"barrier"> {
 }
 
 def GPU_GPUModuleOp : GPU_Op<"module", [
-  IsolatedFromAbove, SymbolTable, Symbol,
+  DataLayoutOpInterface, HasDefaultDLTIDataLayout, IsolatedFromAbove,
+  SymbolTable, Symbol,
   SingleBlockImplicitTerminator<"ModuleEndOp">
 ]> {
   let summary = "A top level compilation unit containing code to be run on a GPU.";

diff  --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 8329bdb103e2..87cac054c55e 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -140,6 +140,7 @@ class DataLayoutDialectInterface
 /// mode, the cache validity is being checked in every request.
 class DataLayout {
 public:
+  explicit DataLayout();
   explicit DataLayout(DataLayoutOpInterface op);
   explicit DataLayout(ModuleOp op);
 

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d5f89f7e7095..e96934088f61 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -108,10 +108,12 @@ struct LowerGpuOpsToNVVMOpsPass
     gpu::GPUModuleOp m = getOperation();
 
     /// Customize the bitwidth used for the device side index computations.
-    LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
-                                  /*emitCWrappers =*/true,
-                                  /*indexBitwidth =*/indexBitwidth,
-                                  /*useAlignedAlloc =*/false};
+    LowerToLLVMOptions options(
+        m.getContext(),
+        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
+    options.emitCWrappers = true;
+    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+      options.overrideIndexBitwidth(indexBitwidth);
 
     /// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
     /// space 5 for private memory attributions, but NVVM represents private

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 6cbf3c2798b0..27ad870691bf 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -54,10 +54,12 @@ struct LowerGpuOpsToROCDLOpsPass
     gpu::GPUModuleOp m = getOperation();
 
     /// Customize the bitwidth used for the device side index computations.
-    LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
-                                  /*emitCWrappers =*/true,
-                                  /*indexBitwidth =*/indexBitwidth,
-                                  /*useAlignedAlloc =*/false};
+    LowerToLLVMOptions options(
+        m.getContext(),
+        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
+    options.emitCWrappers = true;
+    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+      options.overrideIndexBitwidth(indexBitwidth);
     LLVMTypeConverter converter(m.getContext(), options);
 
     RewritePatternSet patterns(m.getContext());

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index f10b29a62026..f064a6532c8a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -273,10 +273,8 @@ class LowerHostCodeToLLVM
 
     // Specify options to lower Standard to LLVM and pull in the conversion
     // patterns.
-    LowerToLLVMOptions options = {
-        /*useBarePtrCallConv=*/false,
-        /*emitCWrappers=*/true,
-        /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
+    LowerToLLVMOptions options(module.getContext());
+    options.emitCWrappers = true;
     auto *context = module.getContext();
     RewritePatternSet patterns(context);
     LLVMTypeConverter typeConverter(context, options);

diff  --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
index e1e13f0b1cc2..13cf5cb16c9f 100644
--- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRDataLayoutInterfaces
   MLIRLLVMIR
   MLIRMath
   MLIRMemRef

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 03251098d5c9..ddfb349cb816 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -101,7 +102,7 @@ LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
 
 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
-    : LLVMTypeConverter(ctx, LowerToLLVMOptions::getDefaultOptions()) {}
+    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx)) {}
 
 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
@@ -109,8 +110,6 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
     : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
       options(options) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
-  if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
-    this->options.indexBitwidth = options.dataLayout.getPointerSizeInBits();
 
   // Register conversions for the builtin types.
   addConversion([&](ComplexType type) { return convertComplexType(type); });
@@ -4074,9 +4073,13 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
 
     ModuleOp m = getOperation();
 
-    LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers,
-                                  indexBitwidth, useAlignedAlloc,
-                                  llvm::DataLayout(this->dataLayout)};
+    LowerToLLVMOptions options(&getContext(), DataLayout(m));
+    options.useBarePtrCallConv = useBarePtrCallConv;
+    options.emitCWrappers = emitCWrappers;
+    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+      options.overrideIndexBitwidth(indexBitwidth);
+    options.useAlignedAlloc = useAlignedAlloc;
+    options.dataLayout = llvm::DataLayout(this->dataLayout);
     LLVMTypeConverter typeConverter(&getContext(), options);
 
     RewritePatternSet patterns(&getContext());
@@ -4098,9 +4101,21 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
   this->addIllegalOp<math::TanhOp>();
 }
 
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createLowerToLLVMPass() {
+  return std::make_unique<LLVMLoweringPass>();
+}
+
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
   return std::make_unique<LLVMLoweringPass>(
-      options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth,
-      options.useAlignedAlloc, options.dataLayout);
+      options.useBarePtrCallConv, options.emitCWrappers,
+      options.getIndexBitwidth(), options.useAlignedAlloc, options.dataLayout);
+}
+
+mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx)
+    : LowerToLLVMOptions(ctx, DataLayout()) {}
+
+mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx,
+                                             mlir::DataLayout dl) {
+  indexBitwidth = dl.getTypeSizeInBits(IndexType::get(ctx));
 }

diff  --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index ea70029c849e..c4895ccb30fa 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -45,6 +45,8 @@ add_mlir_dialect_library(MLIRGPU
 
   LINK_LIBS PUBLIC
   MLIRAsync
+  MLIRDataLayoutInterfaces
+  MLIRDLTI
   MLIREDSC
   MLIRIR
   MLIRMemRef

diff  --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index c91a72344c59..9f5c75a425fb 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -31,6 +31,15 @@ static LLVM_ATTRIBUTE_NORETURN void reportMissingDataLayout(Type type) {
   llvm::report_fatal_error(os.str());
 }
 
+/// Returns the bitwidth of the index type if specified in the param list.
+/// Assumes 64-bit index otherwise.
+static unsigned getIndexBitwidth(DataLayoutEntryListRef params) {
+  if (params.empty())
+    return 64;
+  auto attr = params.front().getValue().cast<IntegerAttr>();
+  return attr.getValue().getZExtValue();
+}
+
 unsigned
 mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout,
                                  ArrayRef<DataLayoutEntryInterface> params) {
@@ -44,6 +53,11 @@ unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
   if (type.isa<IntegerType, FloatType>())
     return type.getIntOrFloatBitWidth();
 
+  // Index is an integer of some bitwidth.
+  if (type.isa<IndexType>())
+    return dataLayout.getTypeSizeInBits(
+        IntegerType::get(type.getContext(), getIndexBitwidth(params)));
+
   // Sizes of vector types are rounded up to those of types with closest
   // power-of-two number of elements in the innermost dimension. We also assume
   // there is no bit-packing at the moment element sizes are taken in bytes and
@@ -67,6 +81,11 @@ unsigned mlir::detail::getDefaultABIAlignment(
   if (type.isa<FloatType, VectorType>())
     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
 
+  // Index is an integer of some bitwidth.
+  if (type.isa<IndexType>())
+    return dataLayout.getTypeABIAlignment(
+        IntegerType::get(type.getContext(), getIndexBitwidth(params)));
+
   if (auto intType = type.dyn_cast<IntegerType>()) {
     return intType.getWidth() < 64
                ? llvm::PowerOf2Ceil(llvm::divideCeil(intType.getWidth(), 8))
@@ -88,7 +107,7 @@ unsigned mlir::detail::getDefaultPreferredAlignment(
 
   // Preferred alignment is the cloest power-of-two number above for integers
   // (ABI alignment may be smaller).
-  if (auto intType = type.dyn_cast<IntegerType>())
+  if (type.isa<IntegerType, IndexType>())
     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
 
   if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
@@ -227,6 +246,8 @@ void checkMissingLayout(DataLayoutSpecInterface originalLayout, OpTy op) {
   }
 }
 
+mlir::DataLayout::DataLayout() : DataLayout(ModuleOp()) {}
+
 mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
     : originalLayout(getCombinedDataLayout(op)), scope(op) {
 #ifndef NDEBUG
@@ -355,6 +376,16 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
 
   for (const auto &kvp : types) {
     auto sampleType = kvp.second.front().getKey().get<Type>();
+    if (sampleType.isa<IndexType>()) {
+      assert(kvp.second.size() == 1 &&
+             "expected one data layout entry for non-parametric 'index' type");
+      if (!kvp.second.front().getValue().isa<IntegerAttr>())
+        return emitError(loc)
+               << "expected integer attribute in the data layout entry for "
+               << sampleType;
+      continue;
+    }
+
     if (isa<BuiltinDialect>(&sampleType.getDialect()))
       return emitError(loc) << "unexpected data layout for a built-in type";
 

diff  --git a/mlir/test/Interfaces/DataLayoutInterfaces/types.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/types.mlir
new file mode 100644
index 000000000000..02adc7f54544
--- /dev/null
+++ b/mlir/test/Interfaces/DataLayoutInterfaces/types.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics --test-data-layout-query %s | FileCheck %s
+
+// expected-error@below {{expected integer attribute in the data layout entry for 'index'}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<index, [32]>>} {
+}
+
+// -----
+
+// CHECK-LABEL: @index
+module @index attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<index, 32>>} {
+  func @query() {
+    // CHECK: bitsize = 32
+    "test.data_layout_query"() : () -> index
+    return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @index_default
+module @index_default {
+  func @query() {
+    // CHECK: bitsize = 64
+    "test.data_layout_query"() : () -> index
+    return
+  }
+}

diff  --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 732c966605bf..1ef697768ba2 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -46,10 +46,8 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   modulePM.addPass(spirv::createLowerABIAttributesPass());
   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
-  LowerToLLVMOptions llvmOptions = {
-      /*useBarePtrCallConv =*/false,
-      /*emitCWrappers = */ true,
-      /*indexBitwidth =*/kDeriveIndexBitwidthFromDataLayout};
+  LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
+  llvmOptions.emitCWrappers = true;
   passManager.addPass(createLowerToLLVMPass(llvmOptions));
   passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
   return passManager.run(module);


        


More information about the Mlir-commits mailing list