[Mlir-commits] [mlir] 2ff6fad - Revert "[mlir] make the bitwidth of device side index computations configurable"
Tobias Gysi
llvmlistbot at llvm.org
Tue Jun 23 10:27:43 PDT 2020
Author: Tobias Gysi
Date: 2020-06-23T19:21:36+02:00
New Revision: 2ff6fad70049b340eeb7cb281c1466fc0169fd17
URL: https://github.com/llvm/llvm-project/commit/2ff6fad70049b340eeb7cb281c1466fc0169fd17
DIFF: https://github.com/llvm/llvm-project/commit/2ff6fad70049b340eeb7cb281c1466fc0169fd17.diff
LOG: Revert "[mlir] make the bitwidth of device side index computations configurable"
This reverts commit d10b1a38a7dfb994623f27f263b67f5fc76e08cc.
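For reference, the configuration surface removed by this revert can be reconstructed from the deleted lines in the diff below. A minimal before/after sketch follows; the variable names and the 32-bit value are illustrative only, while the types, functions, and the index-bitwidth pass option are taken from the diff itself:

    // Before the revert: the bitwidth of device-side index computations was
    // configurable, either via the pass option
    //   mlir-opt -convert-gpu-to-nvvm='index-bitwidth=32' ...
    // or programmatically through LowerToLLVMOptions and the pass constructor.
    LowerToLLVMOptions options = {/*useBarePtrCallConv=*/false,
                                  /*emitCWrappers=*/true,
                                  /*indexBitwidth=*/32,
                                  /*useAlignedAlloc=*/false};
    LLVMTypeConverter converter(moduleOp.getContext(), options);
    auto pass = createLowerGpuOpsToNVVMOpsPass(/*indexBitwidth=*/32);

    // After the revert: the GPU lowering passes take no bitwidth argument, and
    // LLVMTypeConverter is customized through LLVMTypeConverterCustomization,
    // whose indexBitwidth defaults to the pointer width from the data layout.
    LLVMTypeConverterCustomization customs;
    customs.indexBitwidth = 32;
    LLVMTypeConverter restored(moduleOp.getContext(), customs);
    auto defaultPass = createLowerGpuOpsToNVVMOpsPass();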
Added:
Modified:
mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 1af13057f2ea..5dbfce9bd00f 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -8,7 +8,6 @@
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include <memory>
namespace mlir {
@@ -25,11 +24,9 @@ class GPUModuleOp;
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
-/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
-/// index bitwidth used for the lowering of the device side index computations
-/// is configurable.
-std::unique_ptr<OperationPass<gpu::GPUModuleOp>> createLowerGpuOpsToNVVMOpsPass(
- unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
+/// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
+std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
+createLowerGpuOpsToNVVMOpsPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 677782b2dc67..1722ae628e88 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -8,7 +8,6 @@
#ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
#define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include <memory>
namespace mlir {
@@ -26,12 +25,9 @@ class GPUModuleOp;
void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
-/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The
-/// index bitwidth used for the lowering of the device side index computations
-/// is configurable.
+/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-createLowerGpuOpsToROCDLOpsPass(
- unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
+createLowerGpuOpsToROCDLOpsPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index c75049ee81e3..48149ced5403 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -100,11 +100,6 @@ def ConvertGpuLaunchFuncToGpuRuntimeCalls : Pass<"launch-func-to-gpu-runtime",
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
let summary = "Generate NVVM operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
- let options = [
- Option<"indexBitwidth", "index-bitwidth", "unsigned",
- /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
- "Bitwidth of the index type, 0 to use size of machine word">
- ];
}
//===----------------------------------------------------------------------===//
@@ -114,11 +109,6 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
let summary = "Generate ROCDL operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
- let options = [
- Option<"indexBitwidth", "index-bitwidth", "unsigned",
- /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
- "Bitwidth of the index type, 0 to use size of machine word">
- ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 5c30fbb89925..a7e4ff2f52cf 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -15,7 +15,6 @@
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace llvm {
@@ -36,6 +35,22 @@ class LLVMDialect;
class LLVMType;
} // namespace LLVM
+/// Set of callbacks that allows the customization of LLVMTypeConverter.
+struct LLVMTypeConverterCustomization {
+ using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
+ SmallVectorImpl<Type> &)>;
+
+ /// Customize the type conversion of function arguments.
+ CustomCallback funcArgConverter;
+
+ /// Used to determine the bitwidth of the LLVM integer type that the index
+ /// type gets lowered to. Defaults to deriving the size from the data layout.
+ unsigned indexBitwidth;
+
+ /// Initialize customization to default callbacks.
+ LLVMTypeConverterCustomization();
+};
+
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
@@ -60,11 +75,13 @@ class LLVMTypeConverter : public TypeConverter {
public:
using TypeConverter::convertType;
- /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
+ /// Create an LLVMTypeConverter using the default
+ /// LLVMTypeConverterCustomization.
LLVMTypeConverter(MLIRContext *ctx);
- /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
- LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
+ /// Create an LLVMTypeConverter using 'custom' customizations.
+ LLVMTypeConverter(MLIRContext *ctx,
+ const LLVMTypeConverterCustomization &custom);
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
@@ -110,7 +127,7 @@ class LLVMTypeConverter : public TypeConverter {
LLVM::LLVMType getIndexType();
/// Gets the bitwidth of the index type when converted to LLVM.
- unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
+ unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; }
protected:
/// LLVM IR module used to parse/create types.
@@ -176,8 +193,8 @@ class LLVMTypeConverter : public TypeConverter {
// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type);
- /// Options for customizing the llvm lowering.
- LowerToLLVMOptions options;
+ /// Callbacks for customizing the type conversion.
+ LLVMTypeConverterCustomization customizations;
};
/// Helper class to produce LLVM dialect operations extracting or inserting
@@ -372,17 +389,11 @@ class UnrankedMemRefDescriptor : public StructBuilder {
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
-/// conversion patterns with access to an LLVMTypeConverter and the
-/// LowerToLLVMOptions.
+/// conversion patterns with access to an LLVMTypeConverter.
class ConvertToLLVMPattern : public ConversionPattern {
public:
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
LLVMTypeConverter &typeConverter,
- const LowerToLLVMOptions &options = {
- /*useBarePtrCallConv=*/false,
- /*emitCWrappers=*/false,
- /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
- /*useAlignedAlloc=*/false},
PatternBenefit benefit = 1);
/// Returns the LLVM dialect.
@@ -434,9 +445,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
protected:
/// Reference to the type converter, with potential extensions.
LLVMTypeConverter &typeConverter;
-
- /// Reference to the llvm lowering options.
- const LowerToLLVMOptions &options;
};
/// Utility class for operation conversions targeting the LLVM dialect that
@@ -445,11 +453,10 @@ template <typename OpTy>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
- const LowerToLLVMOptions &options,
PatternBenefit benefit = 1)
: ConvertToLLVMPattern(OpTy::getOperationName(),
&typeConverter.getContext(), typeConverter,
- options, benefit) {}
+ benefit) {}
};
namespace LLVM {
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 75bc8eb08886..5479f189b73c 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -14,50 +14,54 @@
namespace mlir {
class LLVMTypeConverter;
class ModuleOp;
-template <typename T>
-class OperationPass;
+template <typename T> class OperationPass;
class OwningRewritePatternList;
-/// Value to pass as bitwidth for the index type when the converter is expected
-/// to derive the bitwidth from the LLVM data layout.
-static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
-
-struct LowerToLLVMOptions {
- bool useBarePtrCallConv = false;
- bool emitCWrappers = false;
- unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
- /// Use aligned_alloc for heap allocations.
- bool useAlignedAlloc = false;
-};
-
/// Collect a set of patterns to convert memory-related operations from the
/// Standard dialect to the LLVM dialect, excluding non-memory-related
/// operations and FuncOp.
void populateStdToLLVMMemoryConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options);
+ bool useAlignedAlloc);
/// Collect a set of patterns to convert from the Standard dialect to the LLVM
/// dialect, excluding the memory-related operations.
void populateStdToLLVMNonMemoryConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options);
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
/// default unpacked form.
-void populateStdToLLVMFuncOpConversionPattern(
+void populateStdToLLVMDefaultFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options);
+ bool emitCWrappers = false);
-/// Collect the patterns to convert from the Standard dialect to LLVM.
-void populateStdToLLVMConversionPatterns(
+/// Collect a set of default patterns to convert from the Standard dialect to
+/// LLVM.
+void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns,
+ bool emitCWrappers = false,
+ bool useAlignedAlloc = false);
+
+/// Collect a set of patterns to convert from the Standard dialect to
+/// LLVM using the bare pointer calling convention for MemRef function
+/// arguments.
+void populateStdToLLVMBarePtrConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options = {
- /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
- /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
- /*useAlignedAlloc=*/false});
+ bool useAlignedAlloc);
+
+/// Value to pass as bitwidth for the index type when the converter is expected
+/// to derive the bitwidth from the LLVM data layout.
+static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
+
+struct LowerToLLVMOptions {
+ bool useBarePtrCallConv = false;
+ bool emitCWrappers = false;
+ unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
+ /// Use aligned_alloc for heap allocations.
+ bool useAlignedAlloc = false;
+};
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
/// stdlib malloc/free is used by default for allocating memrefs allocated with
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 0f5691460ee1..e4fabe4f441e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -30,6 +30,7 @@ using namespace mlir;
namespace {
+
struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(),
@@ -96,27 +97,17 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
-struct LowerGpuOpsToNVVMOpsPass
+class LowerGpuOpsToNVVMOpsPass
: public ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
- LowerGpuOpsToNVVMOpsPass() = default;
- LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
- this->indexBitwidth = indexBitwidth;
- }
-
+public:
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
- /// Customize the bitwidth used for the device side index computations.
- LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
- /*emitCWrappers = */ true,
- /*indexBitwidth =*/indexBitwidth,
- /*useAlignedAlloc =*/false};
-
/// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
/// space 5 for private memory attributions, but NVVM represents private
/// memory allocations as local `alloca`s in the default address space. This
/// converter drops the private memory space to support the use case above.
- LLVMTypeConverter converter(m.getContext(), options);
+ LLVMTypeConverter converter(m.getContext());
converter.addConversion([&](MemRefType type) -> Optional<Type> {
if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
return llvm::None;
@@ -185,6 +176,6 @@ void mlir::populateGpuToNVVMConversionPatterns(
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
- return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth);
+mlir::createLowerGpuOpsToNVVMOpsPass() {
+ return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
}
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index f35ea6d2b3c7..2381d615f91b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -41,22 +41,13 @@ namespace {
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
-struct LowerGpuOpsToROCDLOpsPass
+class LowerGpuOpsToROCDLOpsPass
: public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
- LowerGpuOpsToROCDLOpsPass() = default;
- LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
- this->indexBitwidth = indexBitwidth;
- }
-
+public:
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
- /// Customize the bitwidth used for the device side index computations.
- LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
- /*emitCWrappers = */ true,
- /*indexBitwidth =*/indexBitwidth,
- /*useAlignedAlloc =*/false};
- LLVMTypeConverter converter(m.getContext(), options);
+ LLVMTypeConverter converter(m.getContext());
OwningRewritePatternList patterns;
@@ -115,6 +106,6 @@ void mlir::populateGpuToROCDLConversionPatterns(
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
- return std::make_unique<LowerGpuOpsToROCDLOpsPass>(indexBitwidth);
+mlir::createLowerGpuOpsToROCDLOpsPass() {
+ return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
}
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 202a04c425d6..19c451fa3fe9 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -51,6 +51,11 @@ static LLVM::LLVMType unwrap(Type type) {
return wrappedLLVMType;
}
+/// Initialize customization to default callbacks.
+LLVMTypeConverterCustomization::LLVMTypeConverterCustomization()
+ : funcArgConverter(structFuncArgTypeConverter),
+ indexBitwidth(kDeriveIndexBitwidthFromDataLayout) {}
+
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
@@ -117,19 +122,20 @@ LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
return success();
}
-/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
+/// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
- : LLVMTypeConverter(ctx, LowerToLLVMOptions()) {}
+ : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization()) {}
-/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
-LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
- const LowerToLLVMOptions &options_)
+/// Create an LLVMTypeConverter using 'custom' customizations.
+LLVMTypeConverter::LLVMTypeConverter(
+ MLIRContext *ctx, const LLVMTypeConverterCustomization &customs)
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
- options(options_) {
+ customizations(customs) {
assert(llvmDialect && "LLVM IR dialect is not registered");
module = &llvmDialect->getLLVMModule();
- if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
- options.indexBitwidth = module->getDataLayout().getPointerSizeInBits();
+ if (customizations.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
+ customizations.indexBitwidth =
+ module->getDataLayout().getPointerSizeInBits();
// Register conversions for the standard types.
addConversion([&](ComplexType type) { return convertComplexType(type); });
@@ -256,15 +262,11 @@ SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
FunctionType type, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
- // Select the argument converter depending on the calling convetion.
- auto funcArgConverter = options.useBarePtrCallConv
- ? barePtrFuncArgTypeConverter
- : structFuncArgTypeConverter;
// Convert argument types one by one and check for errors.
for (auto &en : llvm::enumerate(type.getInputs())) {
Type type = en.value();
SmallVector<Type, 8> converted;
- if (failed(funcArgConverter(*this, type, converted)))
+ if (failed(customizations.funcArgConverter(*this, type, converted)))
return {};
result.addInputs(en.index(), converted);
}
@@ -395,10 +397,9 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
MLIRContext *context,
LLVMTypeConverter &typeConverter_,
- const LowerToLLVMOptions &options_,
PatternBenefit benefit)
: ConversionPattern(rootOpName, benefit, typeConverter_, context),
- typeConverter(typeConverter_), options(options_) {}
+ typeConverter(typeConverter_) {}
/*============================================================================*/
/* StructBuilder implementation */
@@ -1050,10 +1051,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
/// information.
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
- FuncOpConversion(LLVMTypeConverter &converter,
- const LowerToLLVMOptions &options)
- : FuncOpConversionBase(converter, options) {}
- using ConvertOpToLLVMPattern<FuncOp>::options;
+ FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers)
+ : FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1064,7 +1063,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
if (!newFuncOp)
return failure();
- if (options.emitCWrappers || funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
+ if (emitWrappers || funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
if (newFuncOp.isExternal())
wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
newFuncOp);
@@ -1076,6 +1075,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
rewriter.eraseOp(op);
return success();
}
+
+private:
+ /// If true, also create the adaptor functions having signatures compatible
+ /// with those produced by clang.
+ const bool emitWrappers;
};
/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
@@ -1502,11 +1506,11 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
using ConvertOpToLLVMPattern<AllocLikeOp>::getIndexType;
using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
- using ConvertOpToLLVMPattern<AllocLikeOp>::options;
explicit AllocLikeOpLowering(LLVMTypeConverter &converter,
- const LowerToLLVMOptions &options)
- : ConvertOpToLLVMPattern<AllocLikeOp>(converter, options) {}
+ bool useAlignedAlloc = false)
+ : ConvertOpToLLVMPattern<AllocLikeOp>(converter),
+ useAlignedAlloc(useAlignedAlloc) {}
LogicalResult match(Operation *op) const override {
MemRefType memRefType = cast<AllocLikeOp>(op).getType();
@@ -1673,7 +1677,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
/// allocation size to be a multiple of alignment,
Optional<int64_t> getAllocationAlignment(AllocOp allocOp) const {
// No alignment can be used for the 'malloc' call itself.
- if (!options.useAlignedAlloc)
+ if (!useAlignedAlloc)
return None;
if (allocOp.alignment())
@@ -1845,14 +1849,16 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
}
protected:
+ /// Use aligned_alloc instead of malloc for all heap allocations.
+ bool useAlignedAlloc;
/// The minimum alignment to use with aligned_alloc (has to be a power of 2).
uint64_t kMinAlignedAllocAlignment = 16UL;
};
struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
explicit AllocOpLowering(LLVMTypeConverter &converter,
- const LowerToLLVMOptions &options)
- : AllocLikeOpLowering<AllocOp>(converter, options) {}
+ bool useAlignedAlloc = false)
+ : AllocLikeOpLowering<AllocOp>(converter, useAlignedAlloc) {}
};
using AllocaOpLowering = AllocLikeOpLowering<AllocaOp>;
@@ -1933,9 +1939,8 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
- explicit DeallocOpLowering(LLVMTypeConverter &converter,
- const LowerToLLVMOptions &options)
- : ConvertOpToLLVMPattern<DeallocOp>(converter, options) {}
+ explicit DeallocOpLowering(LLVMTypeConverter &converter)
+ : ConvertOpToLLVMPattern<DeallocOp>(converter) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2955,8 +2960,7 @@ struct GenericAtomicRMWOpLowering
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
void mlir::populateStdToLLVMNonMemoryConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options) {
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
// clang-format off
patterns.insert<
@@ -3019,13 +3023,13 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
UnsignedRemIOpLowering,
UnsignedShiftRightOpLowering,
XOrOpLowering,
- ZeroExtendIOpLowering>(converter, options);
+ ZeroExtendIOpLowering>(converter);
// clang-format on
}
void mlir::populateStdToLLVMMemoryConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options) {
+ bool useAlignedAlloc) {
// clang-format off
patterns.insert<
AssumeAlignmentOpLowering,
@@ -3035,26 +3039,41 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
MemRefCastOpLowering,
StoreOpLowering,
SubViewOpLowering,
- ViewOpLowering,
- AllocOpLowering>(converter, options);
+ ViewOpLowering>(converter);
+ patterns.insert<
+ AllocOpLowering
+ >(converter, useAlignedAlloc);
// clang-format on
}
-void mlir::populateStdToLLVMFuncOpConversionPattern(
+void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options) {
- if (options.useBarePtrCallConv)
- patterns.insert<BarePtrFuncOpConversion>(converter, options);
- else
- patterns.insert<FuncOpConversion>(converter, options);
+ bool emitCWrappers) {
+ patterns.insert<FuncOpConversion>(converter, emitCWrappers);
}
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- const LowerToLLVMOptions &options) {
- populateStdToLLVMFuncOpConversionPattern(converter, patterns, options);
- populateStdToLLVMNonMemoryConversionPatterns(converter, patterns, options);
- populateStdToLLVMMemoryConversionPatterns(converter, patterns, options);
+ bool emitCWrappers, bool useAlignedAlloc) {
+ populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
+ emitCWrappers);
+ populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
+ populateStdToLLVMMemoryConversionPatterns(converter, patterns,
+ useAlignedAlloc);
+}
+
+static void populateStdToLLVMBarePtrFuncOpConversionPattern(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ patterns.insert<BarePtrFuncOpConversion>(converter);
+}
+
+void mlir::populateStdToLLVMBarePtrConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+ bool useAlignedAlloc) {
+ populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
+ populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
+ populateStdToLLVMMemoryConversionPatterns(converter, patterns,
+ useAlignedAlloc);
}
// Create an LLVM IR structure type if there is more than one result.
@@ -3144,12 +3163,19 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
ModuleOp m = getOperation();
- LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers,
- indexBitwidth, useAlignedAlloc};
- LLVMTypeConverter typeConverter(&getContext(), options);
+ LLVMTypeConverterCustomization customs;
+ customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
+ : structFuncArgTypeConverter;
+ customs.indexBitwidth = indexBitwidth;
+ LLVMTypeConverter typeConverter(&getContext(), customs);
OwningRewritePatternList patterns;
- populateStdToLLVMConversionPatterns(typeConverter, patterns, options);
+ if (useBarePtrCallConv)
+ populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
+ useAlignedAlloc);
+ else
+ populateStdToLLVMConversionPatterns(typeConverter, patterns,
+ emitCWrappers, useAlignedAlloc);
LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, patterns)))
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index c4e39c75cd6f..20d166bab05d 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1,52 +1,36 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
gpu.module @test_module {
// CHECK-LABEL: func @gpu_index_ops()
- // CHECK32-LABEL: func @gpu_index_ops()
func @gpu_index_ops()
-> (index, index, index, index, index, index,
index, index, index, index, index, index) {
- // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
-
// CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
// CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
@@ -58,21 +42,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK-LABEL: func @gpu_index_comp
- // CHECK32-LABEL: func @gpu_index_comp
- func @gpu_index_comp(%idx : index) -> index {
- // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64
- // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
- %0 = addi %idx, %idx : index
- // CHECK: llvm.return %{{.*}} : !llvm.i64
- // CHECK32: llvm.return %{{.*}} : !llvm.i32
- std.return %0 : index
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK-LABEL: func @gpu_all_reduce_op()
gpu.func @gpu_all_reduce_op() {
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index a7565bb6e323..61becff83c6c 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -1,52 +1,36 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
gpu.module @test_module {
// CHECK-LABEL: func @gpu_index_ops()
- // CHECK32-LABEL: func @gpu_index_ops()
func @gpu_index_ops()
-> (index, index, index, index, index, index,
index, index, index, index, index, index) {
- // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
-
// CHECK: rocdl.workitem.id.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
// CHECK: rocdl.workitem.id.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
// CHECK: rocdl.workitem.id.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)
// CHECK: rocdl.workgroup.dim.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
// CHECK: rocdl.workgroup.dim.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
// CHECK: rocdl.workgroup.dim.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)
// CHECK: rocdl.workgroup.id.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
// CHECK: rocdl.workgroup.id.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
// CHECK: rocdl.workgroup.id.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)
// CHECK: rocdl.grid.dim.x : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
// CHECK: rocdl.grid.dim.y : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
// CHECK: rocdl.grid.dim.z : !llvm.i32
- // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
%gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
@@ -58,21 +42,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK-LABEL: func @gpu_index_comp
- // CHECK32-LABEL: func @gpu_index_comp
- func @gpu_index_comp(%idx : index) -> index {
- // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64
- // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
- %0 = addi %idx, %idx : index
- // CHECK: llvm.return %{{.*}} : !llvm.i64
- // CHECK32: llvm.return %{{.*}} : !llvm.i32
- std.return %0 : index
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK-LABEL: func @gpu_sync()
func @gpu_sync() {