[Mlir-commits] [mlir] 5a17780 - [mlir] use unpacked memref descriptors at function boundaries
Alex Zinenko
llvmlistbot at llvm.org
Mon Feb 10 06:03:58 PST 2020
Author: Alex Zinenko
Date: 2020-02-10T15:03:43+01:00
New Revision: 5a1778057f72b8e0444a7932144a3fa441b641bc
URL: https://github.com/llvm/llvm-project/commit/5a1778057f72b8e0444a7932144a3fa441b641bc
DIFF: https://github.com/llvm/llvm-project/commit/5a1778057f72b8e0444a7932144a3fa441b641bc.diff
LOG: [mlir] use unpacked memref descriptors at function boundaries
The existing (default) calling convention for memrefs in standard-to-LLVM
conversion was motivated by interfacing with LLVM IR produced from C sources.
In particular, it passes a pointer to the memref descriptor structure when
calling the function. Therefore, the descriptor is allocated on the stack before
the call. This convention leads to several problems. PR44644 indicates a
problem with stack exhaustion when calling functions with memref-typed
arguments in a loop. Hoisting the allocation out of the loop may lead to
concurrent access problems if the loop is parallel. When targeting GPUs, the
contents of the stack-allocated memory for the descriptor (passed by pointer)
need to be explicitly copied to the device. Using an aggregate type makes it impossible
to attach pointer-specific argument attributes pertaining to alignment and
aliasing in the LLVM dialect.
Change the default calling convention for memrefs in standard-to-LLVM
conversion to transform a memref into a list of arguments, each of primitive
type, that make up the memref descriptor. This avoids stack allocation
for ranked memrefs (and thus stack exhaustion and potential concurrent access
problems) and simplifies the device function invocation on GPUs.
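In C++ terms, the new convention for a rank-1 `f32` memref looks roughly like the code below. This is a hand-written sketch, not code from the patch. `Desc1D`, `sumOld`, `sumNew`, `callNew` and `numUnpackedValues` are hypothetical names, and `int64_t` stands in for the `i64` index type. It contrasts passing the descriptor by pointer with passing its scalar fields one by one.

```cpp
#include <cstdint>

// Hypothetical C++ mirror of the rank-1 descriptor type
// { float*, float*, i64, [1 x i64], [1 x i64] } (i64 modeled as int64_t).
struct Desc1D {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];
};

// Number of scalar arguments a rank-N memref expands into:
// two pointers, one offset, N sizes and N strides.
constexpr unsigned numUnpackedValues(unsigned rank) { return 3 + 2 * rank; }

// Old convention (sketch): the callee receives a pointer, so the caller must
// first store the descriptor in memory, typically on its stack.
float sumOld(const Desc1D *d) {
  float s = 0.0f;
  for (int64_t i = 0; i < d->sizes[0]; ++i)
    s += d->aligned[d->offset + i * d->strides[0]];
  return s;
}

// New convention (sketch): each scalar field is a separate argument, so no
// memory needs to be reserved for the descriptor at the call site.
float sumNew(float *allocated, float *aligned, int64_t offset, int64_t size0,
             int64_t stride0) {
  (void)allocated; // Only relevant for deallocation; unused here.
  float s = 0.0f;
  for (int64_t i = 0; i < size0; ++i)
    s += aligned[offset + i * stride0];
  return s;
}

// Call site under the new convention: the descriptor is unpacked field by
// field, much like the llvm.extractvalue sequence at a converted call site.
float callNew(const Desc1D &d) {
  return sumNew(d.allocated, d.aligned, d.offset, d.sizes[0], d.strides[0]);
}
```

A rank-2 memref would expand to `numUnpackedValues(2) == 7` arguments. This matches the seven-argument `@qux` and `@foo` examples in the documentation change below.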
Provide an option in the standard-to-LLVM conversion to generate auxiliary
wrapper functions with the same interface as the previous calling convention,
compatible with LLVM IR produced from C sources. These auxiliary functions
pack the individual values into a descriptor structure or unpack it. They also
handle descriptor stack allocation if necessary, serving as an allocation
scope: the memory reserved by `alloca` will be freed on exiting the auxiliary
function.
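The two kinds of auxiliary functions can be approximated in C++ as below. This is a sketch under stated assumptions, not the IR the conversion emits. `MemRef1D`, `scale` and `ext_fill` are hypothetical, and only the `_mlir_ciface_<name>` naming scheme comes from this patch's documentation. The first wrapper unpacks a descriptor received by pointer. The second packs unpacked values into a descriptor that lives only in its own frame.

```cpp
#include <cstdint>

// Hypothetical C-compatible rank-1 descriptor, laid out like the struct
// { float*, float*, i64, [1 x i64], [1 x i64] }.
struct MemRef1D {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];
};

// Stand-in for an MLIR-generated function using the unpacked convention.
void scale(float *allocated, float *aligned, int64_t offset, int64_t size,
           int64_t stride) {
  (void)allocated;
  for (int64_t i = 0; i < size; ++i)
    aligned[offset + i * stride] *= 2.0f;
}

// Wrapper for a function *defined* in the module: C callers pass a pointer,
// and the wrapper loads the descriptor and unpacks it.
extern "C" void _mlir_ciface_scale(MemRef1D *d) {
  scale(d->allocated, d->aligned, d->offset, d->sizes[0], d->strides[0]);
}

// Stand-in for an *external* C function expecting a descriptor pointer.
extern "C" void _mlir_ciface_ext_fill(MemRef1D *d) {
  for (int64_t i = 0; i < d->sizes[0]; ++i)
    d->aligned[d->offset + i * d->strides[0]] = 1.0f;
}

// Wrapper for the external function: MLIR code calls it with unpacked values.
// The descriptor is packed into this frame's automatic storage, so it is
// released as soon as the wrapper returns (the "allocation scope").
void ext_fill(float *allocated, float *aligned, int64_t offset, int64_t size,
              int64_t stride) {
  MemRef1D d{allocated, aligned, offset, {size}, {stride}};
  _mlir_ciface_ext_fill(&d);
}
```

Because the descriptor storage lives only inside `ext_fill`, calling it repeatedly in a loop does not accumulate stack usage in the caller.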
The effect of this change on LLVM IR generated only by MLIR is minimal. When
interfacing MLIR-generated LLVM IR with C-generated LLVM IR, the integration
only needs to request the auxiliary functions and change the function name to
call the wrapper function instead of the original function.
This also opens the door to forwarding aliasing and alignment information from
memrefs to LLVM IR pointers in the standard-to-LLVM conversion.
Added:
Modified:
mlir/docs/ConversionToLLVMDialect.md
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir
mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir
mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/Linalg/llvm.mlir
mlir/test/mlir-cpu-runner/cblas_interface.cpp
mlir/test/mlir-cpu-runner/include/cblas_interface.h
mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
Removed:
################################################################################
diff --git a/mlir/docs/ConversionToLLVMDialect.md b/mlir/docs/ConversionToLLVMDialect.md
index 521e02449341..1e6fdedf933e 100644
--- a/mlir/docs/ConversionToLLVMDialect.md
+++ b/mlir/docs/ConversionToLLVMDialect.md
@@ -248,58 +248,123 @@ func @bar() {
### Calling Convention for `memref`
-For function _arguments_ of `memref` type, ranked or unranked, the type of the
-argument is a _pointer_ to the memref descriptor type defined above. The caller
-of such function is required to store the descriptor in memory and guarantee
-that the storage remains live until the callee returns. The caller can than pass
-the pointer to that memory as function argument. The callee loads from the
-pointers it was passed as arguments in the entry block of the function, making
-the descriptor passed in as argument available for use similarly to
-ocally-defined descriptors.
+Function _arguments_ of `memref` type, ranked or unranked, are _expanded_ into a
+list of arguments of non-aggregate types, one per field of the memref
+descriptor defined above. That is, the outer struct type and the inner array
+types are replaced with individual arguments.
This convention is implemented in the conversion of `std.func` and `std.call` to
-the LLVM dialect. Conversions from other dialects should take it into account.
-The motivation for this convention is to simplify the ABI for interfacing with
-other LLVM modules, in particular those generated from C sources, while avoiding
-platform-specific aspects until MLIR has a proper ABI modeling.
+the LLVM dialect, with the former packing the individual argument values back
+into a descriptor at function entry so as to make it transparently usable by
+other operations, and the latter unpacking the descriptor into a set of
+individual values at the call site. Conversions from other dialects should take
+this convention into account.
-Example:
+This specific convention is motivated by the necessity to specify alignment and
+aliasing attributes on the raw pointers underpinning the memref.
+
+Examples:
```mlir
+func @foo(%arg0: memref<?xf32>) -> () {
+ "use"(%arg0) : (memref<?xf32>) -> ()
+ return
+}
-func @foo(memref<?xf32>) -> () {
- %c0 = constant 0 : index
- load %arg0[%c0] : memref<?xf32>
+// Gets converted to the following.
+
+llvm.func @foo(%arg0: !llvm<"float*">, // Allocated pointer.
+ %arg1: !llvm<"float*">, // Aligned pointer.
+ %arg2: !llvm.i64, // Offset.
+ %arg3: !llvm.i64, // Size in dim 0.
+ %arg4: !llvm.i64) { // Stride in dim 0.
+ // Populate memref descriptor structure.
+ %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+ // Descriptor is now usable as a single value.
+ "use"(%5) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">) -> ()
+ llvm.return
+}
+```
+
+```mlir
+func @bar() {
+ %0 = "get"() : () -> (memref<?xf32>)
+ call @foo(%0) : (memref<?xf32>) -> ()
return
}
-func @bar(%arg0: index) {
- %0 = alloc(%arg0) : memref<?xf32>
- call @foo(%0) : (memref<?xf32>)-> ()
+// Gets converted to the following.
+
+llvm.func @bar() {
+ %0 = "get"() : () -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+ // Unpack the memref descriptor.
+ %1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %5 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+ // Pass individual values to the callee.
+ llvm.call @foo(%1, %2, %3, %4, %5) : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
+ llvm.return
+}
+
+```
+
+For **unranked** memrefs, the list of function arguments always contains two
+elements, matching the unranked memref descriptor: an integer rank and a
+type-erased (`!llvm<"i8*">`) pointer to the ranked memref descriptor. Note that
+while the _calling convention_ does not require stack allocation, _casting_ to
+an unranked memref does, since one cannot take the address of an SSA value
+containing the ranked memref descriptor. The caller is responsible for ensuring
+thread safety and, eventually, for removing unnecessary stack allocations in
+cast operations.
+
+Example:
+
+```mlir
+func @foo(%arg0: memref<*xf32>) -> () {
+ "use"(%arg0) : (memref<*xf32>) -> ()
return
}
-// Gets converted to the following IR.
-// Accepts a pointer to the memref descriptor.
-llvm.func @foo(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) {
- // Loads the descriptor so that it can be used similarly to locally
- // created descriptors.
- %0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
+// Gets converted to the following.
+
+llvm.func @foo(%arg0: !llvm.i64, // Rank.
+ %arg1: !llvm<"i8*">) { // Type-erased pointer to descriptor.
+ // Pack the unranked memref descriptor.
+ %0 = llvm.mlir.undef : !llvm<"{ i64, i8* }">
+ %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ i64, i8* }">
+ %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ i64, i8* }">
+
+ "use"(%2) : (!llvm<"{ i64, i8* }">) -> ()
+ llvm.return
+}
+```
+
+```mlir
+func @bar() {
+ %0 = "get"() : () -> (memref<*xf32>)
+ call @foo(%0) : (memref<*xf32>) -> ()
+ return
}
-llvm.func @bar(%arg0: !llvm.i64) {
- // ... Allocation ...
- // Definition of the descriptor.
- %7 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
- // ... Filling in the descriptor ...
- %14 = // The final value of the allocated descriptor.
- // Allocate the memory for the descriptor and store it.
- %15 = llvm.mlir.constant(1 : index) : !llvm.i64
- %16 = llvm.alloca %15 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
- : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
- llvm.store %14, %16 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
- // Pass the pointer to the function.
- llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
+// Gets converted to the following.
+
+llvm.func @bar() {
+ %0 = "get"() : () -> (!llvm<"{ i64, i8* }">)
+
+ // Unpack the memref descriptor.
+ %1 = llvm.extractvalue %0[0] : !llvm<"{ i64, i8* }">
+ %2 = llvm.extractvalue %0[1] : !llvm<"{ i64, i8* }">
+
+ // Pass individual values to the callee.
+ llvm.call @foo(%1, %2) : (!llvm.i64, !llvm<"i8*">) -> ()
llvm.return
}
```
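In C++ terms, the unranked convention amounts to passing a rank and a type-erased pointer. The sketch below also shows why the cast to an unranked memref needs addressable storage for the ranked descriptor. It is illustrative only: `UnrankedMemRef`, `Ranked2D`, `numElements` and `caller` are hypothetical names, and `i64` is modeled as `int64_t`.

```cpp
#include <cstdint>

// Hypothetical mirror of the unranked descriptor { i64, i8* }.
struct UnrankedMemRef {
  int64_t rank;
  void *descriptor; // Type-erased pointer to a ranked descriptor.
};

// Hypothetical rank-2 descriptor { float*, float*, i64, [2 x i64], [2 x i64] }.
struct Ranked2D {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
};

// Callee using the unpacked convention: exactly two scalar arguments.
int64_t numElements(int64_t rank, void *descriptor) {
  if (rank != 2)
    return -1; // This sketch only understands rank-2 descriptors.
  auto *d = static_cast<Ranked2D *>(descriptor);
  return d->sizes[0] * d->sizes[1];
}

// The ranked -> unranked cast needs the ranked descriptor to live in memory
// so that its address can be taken; here it lives in the caller's frame.
int64_t caller() {
  float data[6] = {};
  Ranked2D ranked{data, data, 0, {2, 3}, {3, 1}};
  UnrankedMemRef unranked{2, &ranked};
  // Unpack { i64, i8* } into two arguments, as in the converted @bar.
  return numElements(unranked.rank, unranked.descriptor);
}
```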
@@ -307,6 +372,141 @@ llvm.func @bar(%arg0: !llvm.i64) {
*This convention may or may not apply if the conversion of MemRef types is
overridden by the user.*
+### C-compatible wrapper emission
+
+In practical cases, it may be desirable to have externally-facing functions
+with a single argument corresponding to each MemRef argument. When interfacing
+with LLVM IR produced from C, the code needs to respect the corresponding
+calling convention. The conversion to the LLVM dialect provides an option to
+generate wrapper functions that take memref descriptors as pointers-to-struct
+compatible with data types produced by Clang when compiling C sources.
+
+More specifically, a memref argument is converted into a pointer-to-struct
+argument of type `{T*, T*, i64, [N x i64], [N x i64]}*` in the wrapper function, where
+`T` is the converted element type and `N` is the memref rank. This type is
+compatible with that produced by Clang for the following C++ structure template
+instantiations or their equivalents in C.
+
+```cpp
+template<typename T, size_t N>
+struct MemRefDescriptor {
+ T *allocated;
+ T *aligned;
+ intptr_t offset;
+ intptr_t sizes[N];
+ intptr_t strides[N];
+};
+```
+
+If enabled, the option will do the following for _external_ functions declared
+in the MLIR module:
+
+1. Declare a new function `_mlir_ciface_<original name>` where memref arguments
+ are converted to pointer-to-struct and the remaining arguments are converted
+ as usual.
+1. Add a body to the original function (making it non-external) that
+ 1. allocates a memref descriptor;
+ 1. populates it;
+ 1. passes the pointer to it into the newly declared interface function;
+ 1. collects the result of the call and returns it to the caller.
+
+For (non-external) functions defined in the MLIR module, the option will:
+
+1. Define a new function `_mlir_ciface_<original name>` where memref arguments
+ are converted to pointer-to-struct and the remaining arguments are converted
+ as usual.
+1. Populate the body of the newly defined function with IR that
+ 1. loads the descriptors from pointers;
+ 1. unpacks the descriptors into individual non-aggregate values;
+ 1. passes these values into the original function;
+ 1. collects the result of the call and returns it to the caller.
+
+Examples:
+
+```mlir
+
+func @qux(%arg0: memref<?x?xf32>)
+
+// Gets converted into the following.
+
+// Function with unpacked arguments.
+llvm.func @qux(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
+ %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
+ %arg6: !llvm.i64) {
+ // Populate memref descriptor (as per calling convention).
+ %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+ // Store the descriptor in a stack-allocated space.
+ %8 = llvm.mlir.constant(1 : index) : !llvm.i64
+ %9 = llvm.alloca %8 x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+ llvm.store %7, %9 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+
+ // Call the interface function.
+ llvm.call @_mlir_ciface_qux(%9) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
+
+ // The stored descriptor will be freed on return.
+ llvm.return
+}
+
+// Interface function.
+llvm.func @_mlir_ciface_qux(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
+```
+
+```mlir
+func @foo(%arg0: memref<?x?xf32>) {
+ return
+}
+
+// Gets converted into the following.
+
+// Function with unpacked arguments.
+llvm.func @foo(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
+ %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
+ %arg6: !llvm.i64) {
+ llvm.return
+}
+
+// Interface function callable from C.
+llvm.func @_mlir_ciface_foo(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+ // Load the descriptor.
+ %0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+
+ // Unpack the descriptor as per calling convention.
+ %1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %5 = llvm.extractvalue %0[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %6 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ %7 = llvm.extractvalue %0[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
+ : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64,
+ !llvm.i64, !llvm.i64) -> ()
+ llvm.return
+}
+```
+
+Rationale: Introducing auxiliary functions for C-compatible interfaces is
+preferred to modifying the calling convention since it minimizes the effect
+of C compatibility on intra-module calls or calls between MLIR-generated
+functions. In particular, when calling external functions from an MLIR module in
+a (parallel) loop, storing a memref descriptor on the stack can lead to stack
+exhaustion and/or concurrent access to the same address. The auxiliary
+interface function serves as an allocation scope in this case. Furthermore, when
+targeting accelerators with separate memory spaces such as GPUs, stack-allocated
+descriptors passed by pointer would have to be transferred to the device memory,
+which introduces significant overhead. In such situations, auxiliary interface
+functions are executed on the host and only pass the values through the device
+function invocation mechanism.
+
## Repeated Successor Removal
Since the goal of the LLVM IR dialect is to reflect LLVM IR in MLIR, the dialect
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 0b8ac9c08cb7..8ab7b17e5458 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -36,8 +36,8 @@ class LLVMType;
/// Set of callbacks that allows the customization of LLVMTypeConverter.
struct LLVMTypeConverterCustomization {
- using CustomCallback =
- std::function<LLVM::LLVMType(LLVMTypeConverter &, Type)>;
+ using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
+ SmallVectorImpl<Type> &)>;
/// Customize the type conversion of function arguments.
CustomCallback funcArgConverter;
@@ -47,19 +47,26 @@ struct LLVMTypeConverterCustomization {
};
/// Callback to convert function argument types. It converts a MemRef function
-/// argument to a struct that contains the descriptor information. Converted
-/// types are promoted to a pointer to the converted type.
-LLVM::LLVMType structFuncArgTypeConverter(LLVMTypeConverter &converter,
- Type type);
+/// argument to a list of non-aggregate types containing descriptor
+/// information, and an UnrankedMemRef function argument to a list containing
+/// the rank and a pointer to a descriptor struct.
+LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
+ Type type,
+ SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
-/// arguments to bare pointers to the MemRef element type. Converted types are
-/// not promoted to pointers.
-LLVM::LLVMType barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
- Type type);
+/// arguments to bare pointers to the MemRef element type.
+LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+ Type type,
+ SmallVectorImpl<Type> &result);
/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
+ /// Give structFuncArgTypeConverter access to memref-specific functions.
+ friend LogicalResult
+ structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
+ SmallVectorImpl<Type> &result);
+
public:
using TypeConverter::convertType;
@@ -107,6 +114,15 @@ class LLVMTypeConverter : public TypeConverter {
Value promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder);
+ /// Converts the function type to a C-compatible format, in particular using
+ /// pointers to memref descriptors for arguments.
+ LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
+
+ /// Creates descriptor structs from individual values constituting them.
+ Operation *materializeConversion(PatternRewriter &rewriter, Type type,
+ ArrayRef<Value> values,
+ Location loc) override;
+
protected:
/// LLVM IR module used to parse/create types.
llvm::Module *module;
@@ -133,14 +149,34 @@ class LLVMTypeConverter : public TypeConverter {
// by LLVM.
Type convertFloatType(FloatType type);
- // Convert a memref type into an LLVM type that captures the relevant data.
- // For statically-shaped memrefs, the resulting type is a pointer to the
- // (converted) memref element type. For dynamically-shaped memrefs, the
- // resulting type is an LLVM structure type that contains:
- // 1. a pointer to the (converted) memref element type
- // 2. as many index types as memref has dynamic dimensions.
+ /// Convert a memref type into an LLVM type that captures the relevant data.
Type convertMemRefType(MemRefType type);
+ /// Convert a memref type into a list of non-aggregate LLVM IR types that
+ /// contain all the relevant data. In particular, the list will contain:
+ /// - two pointers to the memref element type, followed by
+ /// - an integer offset, followed by
+ /// - one integer size per dimension of the memref, followed by
+ /// - one integer stride per dimension of the memref.
+ /// For example, memref<?x?xf32> is converted to the following list:
+ /// - `!llvm<"float*">` (allocated pointer),
+ /// - `!llvm<"float*">` (aligned pointer),
+ /// - `!llvm.i64` (offset),
+ /// - `!llvm.i64`, `!llvm.i64` (sizes),
+ /// - `!llvm.i64`, `!llvm.i64` (strides).
+ /// These types can be recomposed to a memref descriptor struct.
+ SmallVector<Type, 5> convertMemRefSignature(MemRefType type);
+
+ /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
+ /// that contain all the relevant data. In particular, this list contains:
+ /// - an integer rank, followed by
+ /// - a pointer to the memref descriptor struct.
+ /// For example, memref<*xf32> is converted to the following list:
+ /// - `!llvm.i64` (rank),
+ /// - `!llvm<"i8*">` (type-erased pointer).
+ /// These types can be recomposed to an unranked memref descriptor struct.
+ SmallVector<Type, 2> convertUnrankedMemRefSignature();
+
// Convert an unranked memref type to an LLVM type that captures the
// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type);
@@ -180,6 +216,7 @@ class StructBuilder {
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};
+
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
@@ -234,11 +271,63 @@ class MemRefDescriptor : public StructBuilder {
/// Returns the (LLVM) type this descriptor points to.
LLVM::LLVMType getElementType();
+ /// Builds IR populating a MemRef descriptor structure from a list of
+ /// individual values composing that descriptor, in the following order:
+ /// - allocated pointer;
+ /// - aligned pointer;
+ /// - offset;
+ /// - <rank> sizes;
+ /// - <rank> strides;
+ /// where <rank> is the MemRef rank as provided in `type`.
+ static Value pack(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &converter, MemRefType type,
+ ValueRange values);
+
+ /// Builds IR extracting individual elements of a MemRef descriptor structure
+ /// and returning them as `results` list.
+ static void unpack(OpBuilder &builder, Location loc, Value packed,
+ MemRefType type, SmallVectorImpl<Value> &results);
+
+ /// Returns the number of non-aggregate values that would be produced by
+ /// `unpack`.
+ static unsigned getNumUnpackedValues(MemRefType type);
+
private:
// Cached index type.
Type indexType;
};
+/// Helper class allowing the user to access a range of Values that correspond
+/// to an unpacked memref descriptor using named accessors. This does not own
+/// the values.
+class MemRefDescriptorView {
+public:
+ /// Constructs the view from a range of values. Infers the rank from the size
+ /// of the range.
+ explicit MemRefDescriptorView(ValueRange range);
+
+ /// Returns the allocated pointer Value.
+ Value allocatedPtr();
+
+ /// Returns the aligned pointer Value.
+ Value alignedPtr();
+
+ /// Returns the offset Value.
+ Value offset();
+
+ /// Returns the pos-th size Value.
+ Value size(unsigned pos);
+
+ /// Returns the pos-th stride Value.
+ Value stride(unsigned pos);
+
+private:
+ /// Rank of the memref the descriptor is pointing to.
+ int rank;
+ /// Underlying range of Values.
+ ValueRange elements;
+};
+
class UnrankedMemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
@@ -255,6 +344,23 @@ class UnrankedMemRefDescriptor : public StructBuilder {
Value memRefDescPtr(OpBuilder &builder, Location loc);
/// Builds IR setting ranked memref descriptor ptr
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
+
+ /// Builds IR populating an unranked MemRef descriptor structure from a list
+ /// of individual constituent values in the following order:
+ /// - rank of the memref;
+ /// - pointer to the memref descriptor.
+ static Value pack(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &converter, UnrankedMemRefType type,
+ ValueRange values);
+
+ /// Builds IR extracting individual elements that compose an unranked memref
+ /// descriptor and returns them as `results` list.
+ static void unpack(OpBuilder &builder, Location loc, Value packed,
+ SmallVectorImpl<Value> &results);
+
+ /// Returns the number of non-aggregate values that would be produced by
+ /// `unpack`.
+ static unsigned getNumUnpackedValues() { return 2; }
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 4179fff8a7b0..8f319029f7b5 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -29,16 +29,21 @@ void populateStdToLLVMMemoryConversionPatters(
void populateStdToLLVMNonMemoryConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
-/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
+/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
+/// `emitCWrappers` is set, the pattern will also produce functions
+/// that pass memref descriptors by pointer-to-structure in addition to the
+/// default unpacked form.
void populateStdToLLVMDefaultFuncOpConversionPattern(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+ bool emitCWrappers = false);
/// Collect a set of default patterns to convert from the Standard dialect to
/// LLVM. If `useAlloca` is set, the patterns for AllocOp and DeallocOp will
/// generate `llvm.alloca` instead of calls to "malloc".
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns,
- bool useAlloca = false);
+ bool useAlloca = false,
+ bool emitCWrappers = false);
/// Collect a set of patterns to convert from the Standard dialect to
/// LLVM using the bare pointer calling convention for MemRef function
@@ -53,7 +58,7 @@ void populateStdToLLVMBarePtrConversionPatterns(
/// Specifying `useAlloca-true` emits stack allocations instead. In the future
/// this may become an enum when we have concrete uses for other options.
std::unique_ptr<OpPassBase<ModuleOp>>
-createLowerToLLVMPass(bool useAlloca = false);
+createLowerToLLVMPass(bool useAlloca = false, bool emitCWrappers = false);
namespace LLVM {
/// Make argument-taking successors of each block distinct. PHI nodes in LLVM
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 4aca5e3a66a1..98793fecb598 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -30,6 +30,13 @@ inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
return ("arg" + Twine(arg)).toStringRef(out);
}
+/// Returns true if the given name is a valid argument attribute name.
+inline bool isArgAttrName(StringRef name) {
+ APInt unused;
+ return name.startswith("arg") &&
+ !name.drop_front(3).getAsInteger(/*Radix=*/10, unused);
+}
+
/// Return the name of the attribute used for function results.
inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
out.clear();
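The `isArgAttrName` helper added to `FunctionSupport.h` above accepts names of the form `arg<N>`. Below is a standalone approximation using only the standard library. `isArgAttrNameSketch` is a hypothetical name, and it assumes that only a non-empty run of plain ASCII decimal digits counts as a valid suffix. The LLVM `StringRef::getAsInteger` edge cases are not reproduced exactly.

```cpp
#include <cctype>
#include <string_view>

// Approximation of isArgAttrName: true for "arg" followed by a non-empty
// run of decimal digits, e.g. "arg0" or "arg12".
bool isArgAttrNameSketch(std::string_view name) {
  constexpr std::string_view prefix = "arg";
  if (name.substr(0, prefix.size()) != prefix)
    return false;
  std::string_view digits = name.substr(prefix.size());
  if (digits.empty())
    return false; // A bare "arg" has no index.
  for (char c : digits)
    if (!std::isdigit(static_cast<unsigned char>(c)))
      return false;
  return true;
}
```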
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index bff5dfbcff34..7a275b42df60 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -113,6 +113,8 @@ class GpuLaunchFuncToCudaCallsPass
}
void declareCudaFunctions(Location loc);
+ void addParamToList(OpBuilder &builder, Location loc, Value param, Value list,
+ unsigned pos, Value one);
Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
Value generateKernelNameConstant(StringRef name, Location loc,
OpBuilder &builder);
@@ -231,6 +233,35 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
}
}
+/// Emits the IR with the following structure:
+///
+/// %data = llvm.alloca 1 x type-of(<param>)
+/// llvm.store <param>, %data
+/// %typeErased = llvm.bitcast %data to !llvm<"i8*">
+/// %addr = llvm.getelementptr <list>[<pos>]
+/// llvm.store %typeErased, %addr
+///
+/// This is necessary to construct the list of arguments passed to the kernel
+/// function as accepted by cuLaunchKernel, i.e. as a void** that points to a
+/// list of stack-allocated type-erased pointers to the actual arguments.
+void GpuLaunchFuncToCudaCallsPass::addParamToList(OpBuilder &builder,
+ Location loc, Value param,
+ Value list, unsigned pos,
+ Value one) {
+ auto memLocation = builder.create<LLVM::AllocaOp>(
+ loc, param.getType().cast<LLVM::LLVMType>().getPointerTo(), one,
+ /*alignment=*/1);
+ builder.create<LLVM::StoreOp>(loc, param, memLocation);
+ auto casted =
+ builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+
+ auto index = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+ builder.getI32IntegerAttr(pos));
+ auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), list,
+ ArrayRef<Value>{index});
+ builder.create<LLVM::StoreOp>(loc, casted, gep);
+}
+
// Generates a parameters array to be used with a CUDA kernel launch call. The
// arguments are extracted from the launchOp.
// The generated code is essentially as follows:
@@ -241,53 +272,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
// return %array
Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
OpBuilder &builder) {
+
+ // Get the launch target.
+ auto containingModule = launchOp.getParentOfType<ModuleOp>();
+ if (!containingModule)
+ return {};
+ auto gpuModule = containingModule.lookupSymbol<gpu::GPUModuleOp>(
+ launchOp.getKernelModuleName());
+ if (!gpuModule)
+ return {};
+ auto gpuFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>(launchOp.kernel());
+ if (!gpuFunc)
+ return {};
+
+ unsigned numArgs = gpuFunc.getNumArguments();
+
auto numKernelOperands = launchOp.getNumKernelOperands();
Location loc = launchOp.getLoc();
auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
builder.getI32IntegerAttr(1));
- // Provision twice as much for the `array` to allow up to one level of
- // indirection for each argument.
auto arraySize = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), builder.getI32IntegerAttr(numKernelOperands));
+ loc, getInt32Type(), builder.getI32IntegerAttr(numArgs));
auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
arraySize, /*alignment=*/0);
+
+ unsigned pos = 0;
for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
auto operand = launchOp.getKernelOperand(idx);
auto llvmType = operand.getType().cast<LLVM::LLVMType>();
- Value memLocation = builder.create<LLVM::AllocaOp>(
- loc, llvmType.getPointerTo(), one, /*alignment=*/1);
- builder.create<LLVM::StoreOp>(loc, operand, memLocation);
- auto casted =
- builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
// Assume all struct arguments come from MemRef. If this assumption no longer
// holds, `launchOp` needs to be lowered from MemRefType, before LLVM
// conversion has taken place and the MemRef information is lost.
- // Extra level of indirection in the `array`:
- // the descriptor pointer is registered via @mcuMemHostRegisterPtr
- if (llvmType.isStructTy()) {
- auto registerFunc =
- getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
- auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
- auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
- ArrayRef<Value>{nullPtr, one});
- auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
- builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
- builder.getSymbolRefAttr(registerFunc),
- ArrayRef<Value>{casted, size});
- Value memLocation = builder.create<LLVM::AllocaOp>(
- loc, getPointerPointerType(), one, /*alignment=*/1);
- builder.create<LLVM::StoreOp>(loc, casted, memLocation);
- casted =
- builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+ if (!llvmType.isStructTy()) {
+ addParamToList(builder, loc, operand, array, pos++, one);
+ continue;
}
- auto index = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), builder.getI32IntegerAttr(idx));
- auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
- ArrayRef<Value>{index});
- builder.create<LLVM::StoreOp>(loc, casted, gep);
+ // Put individual components of a memref descriptor into the flat argument
+ // list. We cannot use unpackMemref from LLVM lowering here because we have
+ // no access to MemRefType that had been lowered away.
+ for (int32_t j = 0, ej = llvmType.getStructNumElements(); j < ej; ++j) {
+ auto elemType = llvmType.getStructElementType(j);
+ if (elemType.isArrayTy()) {
+ for (int32_t k = 0, ek = elemType.getArrayNumElements(); k < ek; ++k) {
+ Value elem = builder.create<LLVM::ExtractValueOp>(
+ loc, elemType.getArrayElementType(), operand,
+ builder.getI32ArrayAttr({j, k}));
+ addParamToList(builder, loc, elem, array, pos++, one);
+ }
+ } else {
+ assert((elemType.isIntegerTy() || elemType.isFloatTy() ||
+ elemType.isDoubleTy() || elemType.isPointerTy()) &&
+ "expected scalar type");
+ Value strct = builder.create<LLVM::ExtractValueOp>(
+ loc, elemType, operand, builder.getI32ArrayAttr(j));
+ addParamToList(builder, loc, strct, array, pos++, one);
+ }
+ }
}
+
return array;
}
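setupParamsArray now walks the descriptor struct. It passes scalar fields directly and expands each array field into one kernel parameter per element. The following C++ sketch reproduces that ordering for a rank-2 descriptor. `MemRef2D` is a hypothetical C layout mirroring `{T*, T*, i64, [2 x i64], [2 x i64]}`. Unlike the pass, it points into the struct instead of copying each element into its own stack slot.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical C layout of a rank-2 memref descriptor.
struct MemRef2D {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
};

// Flattens the descriptor in struct order: scalar fields as-is, array fields
// element by element. The returned pointers refer to fields of `d`, so `d`
// must outlive the result.
std::vector<void *> flattenForKernel(MemRef2D &d) {
  std::vector<void *> params;
  params.push_back(&d.allocated);
  params.push_back(&d.aligned);
  params.push_back(&d.offset);
  for (int64_t &s : d.sizes)
    params.push_back(&s);
  for (int64_t &s : d.strides)
    params.push_back(&s);
  return params;
}
```

A rank-2 memref therefore occupies 3 + 2 * 2 = 7 kernel parameters.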
@@ -392,6 +436,10 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
auto cuFunctionRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
auto paramsArray = setupParamsArray(launchOp, builder);
+ if (!paramsArray) {
+ launchOp.emitOpError() << "cannot pass given parameters to the kernel";
+ return signalPassFailure();
+ }
auto nullpointer =
builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
builder.create<LLVM::CallOp>(
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index ab47dcfb685f..09fbd3095168 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -564,8 +564,8 @@ struct GPUFuncOpLowering : LLVMOpLowering {
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
- for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
- signatureConversion.addInputs(i, funcType.getFunctionParamType(i));
+ lowering.convertFunctionSignature(gpuFuncOp.getType(), /*isVariadic=*/false,
+ signatureConversion);
// Create the new function operation. Only copy those attributes that are
// not specific to function modeling.
@@ -651,25 +651,6 @@ struct GPUFuncOpLowering : LLVMOpLowering {
rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
signatureConversion);
- {
- // For memref-typed arguments, insert the relevant loads in the beginning
- // of the block to comply with the LLVM dialect calling convention. This
- // needs to be done after signature conversion to get the right types.
- OpBuilder::InsertionGuard guard(rewriter);
- Block &block = llvmFuncOp.front();
- rewriter.setInsertionPointToStart(&block);
-
- for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) {
- if (!en.value().isa<MemRefType>() &&
- !en.value().isa<UnrankedMemRefType>())
- continue;
-
- BlockArgument arg = block.getArgument(en.index());
- Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
- rewriter.replaceUsesOfBlockArgument(arg, loaded);
- }
- }
-
rewriter.eraseOp(gpuFuncOp);
return matchSuccess();
}
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 6baa7dbfe519..90312222e735 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -577,7 +577,8 @@ void ConvertLinalgToLLVMPass::runOnModule() {
LinalgTypeConverter converter(&getContext());
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
- populateStdToLLVMConversionPatterns(converter, patterns);
+ populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
+ /*emitCWrappers=*/true);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateLinalgToStandardConversionPatterns(patterns, &getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index d0668cad0fe4..63c8b75f197f 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -30,6 +30,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
@@ -43,6 +44,11 @@ static llvm::cl::opt<bool>
llvm::cl::desc("Replace emission of malloc/free by alloca"),
llvm::cl::init(false));
+static llvm::cl::opt<bool>
+ clEmitCWrappers(PASS_NAME "-emit-c-wrappers",
+ llvm::cl::desc("Emit C-compatible wrapper functions"),
+ llvm::cl::init(false));
+
static llvm::cl::opt<bool> clUseBarePtrCallConv(
PASS_NAME "-use-bare-ptr-memref-call-conv",
llvm::cl::desc("Replace FuncOp's MemRef arguments with "
@@ -66,18 +72,32 @@ LLVMTypeConverterCustomization::LLVMTypeConverterCustomization() {
funcArgConverter = structFuncArgTypeConverter;
}
-// Callback to convert function argument types. It converts a MemRef function
-// arguments to a struct that contains the descriptor information. Converted
-// types are promoted to a pointer to the converted type.
-LLVM::LLVMType mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
- Type type) {
- auto converted =
- converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+/// Callback to convert function argument types. It converts a MemRef function
+/// argument to a list of non-aggregate types containing descriptor
+/// information, and an UnrankedMemRef function argument to a list containing
+/// the rank and a pointer to a descriptor struct.
+LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
+ Type type,
+ SmallVectorImpl<Type> &result) {
+ if (auto memref = type.dyn_cast<MemRefType>()) {
+ auto converted = converter.convertMemRefSignature(memref);
+ if (converted.empty())
+ return failure();
+ result.append(converted.begin(), converted.end());
+ return success();
+ }
+ if (type.isa<UnrankedMemRefType>()) {
+ auto converted = converter.convertUnrankedMemRefSignature();
+ if (converted.empty())
+ return failure();
+ result.append(converted.begin(), converted.end());
+ return success();
+ }
+ auto converted = converter.convertType(type);
if (!converted)
- return {};
- if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
- converted = converted.getPointerTo();
- return converted;
+ return failure();
+ result.push_back(converted);
+ return success();
}
/// Convert a MemRef type to a bare pointer to the MemRef element type.
@@ -96,15 +116,26 @@ static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter,
}
/// Callback to convert function argument types. It converts MemRef function
-/// arguments to bare pointers to the MemRef element type. Converted types are
-/// not promoted to pointers.
-LLVM::LLVMType mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
- Type type) {
+/// arguments to bare pointers to the MemRef element type.
+LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+ Type type,
+ SmallVectorImpl<Type> &result) {
// TODO: Add support for unranked memref.
- if (auto memrefTy = type.dyn_cast<MemRefType>())
- return convertMemRefTypeToBarePtr(converter, memrefTy)
- .dyn_cast_or_null<LLVM::LLVMType>();
- return converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+ if (auto memrefTy = type.dyn_cast<MemRefType>()) {
+ auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy);
+ if (!llvmTy)
+ return failure();
+
+ result.push_back(llvmTy);
+ return success();
+ }
+
+ auto llvmTy = converter.convertType(type);
+ if (!llvmTy)
+ return failure();
+
+ result.push_back(llvmTy);
+ return success();
}
/// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
@@ -165,6 +196,33 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
return converted.getPointerTo();
}
+/// In signatures, MemRef descriptors are expanded into lists of non-aggregate
+/// values.
+SmallVector<Type, 5>
+LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
+ SmallVector<Type, 5> results;
+ assert(isStrided(type) &&
+ "Non-strided layout maps must have been normalized away");
+
+ LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+ if (!elementType)
+ return {};
+ auto indexTy = getIndexType();
+
+ results.insert(results.begin(), 2,
+ elementType.getPointerTo(type.getMemorySpace()));
+ results.push_back(indexTy);
+ auto rank = type.getRank();
+ results.insert(results.end(), 2 * rank, indexTy);
+ return results;
+}
+
+/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
+/// pointer to descriptor".
+SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
+ return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
+}
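convertMemRefSignature and convertUnrankedMemRefSignature determine how many arguments a memref occupies in a converted signature. This is a plain C++ model of that expansion. Types are spelled as strings rather than MLIR/LLVM types, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// A ranked memref of rank `rank` expands to: two element pointers (allocated
// and aligned), one index (offset), `rank` indices (sizes) and `rank` indices
// (strides).
std::vector<std::string> rankedMemRefSignature(const std::string &elemPtr,
                                               unsigned rank) {
  std::vector<std::string> result(2, elemPtr);
  result.push_back("index");
  result.insert(result.end(), static_cast<std::size_t>(2 * rank),
                std::string("index"));
  return result;
}

// An unranked memref expands to the pair (rank, pointer to a descriptor).
std::vector<std::string> unrankedMemRefSignature() {
  return {"index", "i8*"};
}
```

Even a rank-0 memref contributes three arguments: both pointers and the offset.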
+
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
@@ -175,9 +233,8 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// Convert argument types one by one and check for errors.
for (auto &en : llvm::enumerate(type.getInputs())) {
Type type = en.value();
- auto converted = customizations.funcArgConverter(*this, type)
- .dyn_cast_or_null<LLVM::LLVMType>();
- if (!converted)
+ SmallVector<Type, 8> converted;
+ if (failed(customizations.funcArgConverter(*this, type, converted)))
return {};
result.addInputs(en.index(), converted);
}
@@ -199,6 +256,47 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
}
+/// Converts the function type to a C-compatible format, in particular using
+/// pointers to memref descriptors for arguments.
+LLVM::LLVMType
+LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+ SmallVector<LLVM::LLVMType, 4> inputs;
+
+ for (Type t : type.getInputs()) {
+ auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
+ if (!converted)
+ return {};
+ if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
+ converted = converted.getPointerTo();
+ inputs.push_back(converted);
+ }
+
+ LLVM::LLVMType resultType =
+ type.getNumResults() == 0
+ ? LLVM::LLVMType::getVoidTy(llvmDialect)
+ : unwrap(packFunctionResults(type.getResults()));
+ if (!resultType)
+ return {};
+
+ return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
+}
+
+/// Creates descriptor structs from individual values constituting them.
+Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
+ Type type,
+ ArrayRef<Value> values,
+ Location loc) {
+ if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
+ return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
+ unrankedMemRefType, values)
+ .getDefiningOp();
+
+ auto memRefType = type.dyn_cast<MemRefType>();
+ assert(memRefType && "1->N conversion is only supported for memrefs");
+ return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
+ .getDefiningOp();
+}
+
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
// 1. the pointer to the data buffer, followed by
@@ -473,6 +571,85 @@ LLVM::LLVMType MemRefDescriptor::getElementType() {
kAlignedPtrPosInMemRefDescriptor);
}
+/// Creates a MemRef descriptor structure from a list of individual values
+/// composing that descriptor, in the following order:
+/// - allocated pointer;
+/// - aligned pointer;
+/// - offset;
+/// - <rank> sizes;
+/// - <rank> strides;
+/// where <rank> is the MemRef rank as provided in `type`.
+Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &converter, MemRefType type,
+ ValueRange values) {
+ Type llvmType = converter.convertType(type);
+ auto d = MemRefDescriptor::undef(builder, loc, llvmType);
+
+ d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
+ d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
+ d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
+
+ int64_t rank = type.getRank();
+ for (unsigned i = 0; i < rank; ++i) {
+ d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
+ d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
+ }
+
+ return d;
+}
+
+/// Builds IR extracting individual elements of a MemRef descriptor structure
+/// and appending them to the `results` list.
+void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
+ MemRefType type,
+ SmallVectorImpl<Value> &results) {
+ int64_t rank = type.getRank();
+ results.reserve(results.size() + getNumUnpackedValues(type));
+
+ MemRefDescriptor d(packed);
+ results.push_back(d.allocatedPtr(builder, loc));
+ results.push_back(d.alignedPtr(builder, loc));
+ results.push_back(d.offset(builder, loc));
+ for (int64_t i = 0; i < rank; ++i)
+ results.push_back(d.size(builder, loc, i));
+ for (int64_t i = 0; i < rank; ++i)
+ results.push_back(d.stride(builder, loc, i));
+}
+
+/// Returns the number of non-aggregate values that would be produced by
+/// `unpack`.
+unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
+ // Two pointers, offset, <rank> sizes, <rank> strides.
+ return 3 + 2 * type.getRank();
+}
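MemRefDescriptor::pack and unpack convert between the descriptor struct and the flat value list, and getNumUnpackedValues gives the length of that list. The round trip is sketched below in standalone C++. The position constants are assumed to mirror the `k*PosInMemRefDescriptor` values used above: allocated pointer, aligned pointer, offset, then sizes followed by strides. Pointers are stored as integers only so that all values share one type in this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr unsigned kAllocatedPtrPos = 0;
constexpr unsigned kAlignedPtrPos = 1;
constexpr unsigned kOffsetPos = 2;
constexpr unsigned kSizePos = 3;

// Hypothetical packed descriptor of arbitrary rank.
struct Descriptor {
  intptr_t allocated, aligned;
  int64_t offset;
  std::vector<int64_t> sizes, strides;
};

unsigned getNumUnpackedValues(unsigned rank) { return 3 + 2 * rank; }

// Struct -> flat list: pointers, offset, sizes, strides.
std::vector<int64_t> unpack(const Descriptor &d) {
  std::vector<int64_t> out = {static_cast<int64_t>(d.allocated),
                              static_cast<int64_t>(d.aligned), d.offset};
  out.insert(out.end(), d.sizes.begin(), d.sizes.end());
  out.insert(out.end(), d.strides.begin(), d.strides.end());
  return out;
}

// Flat list -> struct; size i sits at kSizePos + i, stride i at
// kSizePos + rank + i.
Descriptor pack(const std::vector<int64_t> &v, unsigned rank) {
  assert(v.size() == getNumUnpackedValues(rank));
  Descriptor d{static_cast<intptr_t>(v[kAllocatedPtrPos]),
               static_cast<intptr_t>(v[kAlignedPtrPos]), v[kOffsetPos], {}, {}};
  for (unsigned i = 0; i < rank; ++i) {
    d.sizes.push_back(v[kSizePos + i]);
    d.strides.push_back(v[kSizePos + rank + i]);
  }
  return d;
}
```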
+
+/*============================================================================*/
+/* MemRefDescriptorView implementation. */
+/*============================================================================*/
+
+MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
+ : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
+
+Value MemRefDescriptorView::allocatedPtr() {
+ return elements[kAllocatedPtrPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::alignedPtr() {
+ return elements[kAlignedPtrPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::offset() {
+ return elements[kOffsetPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::size(unsigned pos) {
+ return elements[kSizePosInMemRefDescriptor + pos];
+}
+
+Value MemRefDescriptorView::stride(unsigned pos) {
+ return elements[kSizePosInMemRefDescriptor + rank + pos];
+}
+
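MemRefDescriptorView reads fields directly from the flat range, so no struct needs to be rebuilt. It derives the rank from the range length as (size - 3) / 2. A non-owning C++ analogue is shown here. `DescriptorView` is a hypothetical class, and it assumes the same field order as the lowering.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Non-owning view over unpacked descriptor values laid out as
// allocated, aligned, offset, sizes..., strides...
// Precondition: count >= 3 and (count - 3) is even.
class DescriptorView {
  static constexpr std::size_t kOffsetPos = 2;
  static constexpr std::size_t kSizePos = 3;
  const int64_t *elements;
  std::size_t rank;

public:
  DescriptorView(const int64_t *data, std::size_t count)
      : elements(data), rank((count - kSizePos) / 2) {}
  std::size_t getRank() const { return rank; }
  int64_t offset() const { return elements[kOffsetPos]; }
  int64_t size(std::size_t pos) const { return elements[kSizePos + pos]; }
  int64_t stride(std::size_t pos) const {
    return elements[kSizePos + rank + pos];
  }
};
```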
/*============================================================================*/
/* UnrankedMemRefDescriptor implementation */
/*============================================================================*/
@@ -504,6 +681,34 @@ void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
Location loc, Value v) {
setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}
+
+/// Builds IR populating an unranked MemRef descriptor structure from a list
+/// of individual constituent values in the following order:
+/// - rank of the memref;
+/// - pointer to the memref descriptor.
+Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &converter,
+ UnrankedMemRefType type,
+ ValueRange values) {
+ Type llvmType = converter.convertType(type);
+ auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
+
+ d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
+ d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
+ return d;
+}
+
+/// Builds IR extracting individual elements that compose an unranked memref
+/// descriptor and appends them to the `results` list.
+void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
+ Value packed,
+ SmallVectorImpl<Value> &results) {
+ UnrankedMemRefDescriptor d(packed);
+ results.reserve(results.size() + 2);
+ results.push_back(d.rank(builder, loc));
+ results.push_back(d.memRefDescPtr(builder, loc));
+}
+
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
@@ -551,9 +756,144 @@ class LLVMLegalizationPattern : public LLVMOpLowering {
LLVM::LLVMDialect &dialect;
};
+/// Only retain those attributes that are not constructed by
+/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
+/// attributes.
+static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
+ bool filterArgAttrs,
+ SmallVectorImpl<NamedAttribute> &result) {
+ for (const auto &attr : attrs) {
+ if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
+ attr.first.is(impl::getTypeAttrName()) ||
+ attr.first.is("std.varargs") ||
+ (filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
+ continue;
+ result.push_back(attr);
+ }
+}
+
+/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
+/// arguments instead of unpacked arguments. This function can be called from C
+/// by passing a pointer to a C struct corresponding to a memref descriptor.
+/// Internally, the auxiliary function unpacks the descriptor into individual
+/// components and forwards them to `newFuncOp`.
+static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
+ LLVMTypeConverter &typeConverter,
+ FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
+ auto type = funcOp.getType();
+ SmallVector<NamedAttribute, 4> attributes;
+ filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
+ auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+ loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+ typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
+ attributes);
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
+
+ SmallVector<Value, 8> args;
+ for (auto &en : llvm::enumerate(type.getInputs())) {
+ Value arg = wrapperFuncOp.getArgument(en.index());
+ if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
+ Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
+ MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
+ continue;
+ }
+ if (en.value().isa<UnrankedMemRefType>()) {
+ Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
+ UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
+ continue;
+ }
+
+ args.push_back(wrapperFuncOp.getArgument(en.index()));
+ }
+ auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
+ rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+}
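wrapForExternalCallers generates a `_mlir_ciface_<name>` entry point. It takes a pointer to a descriptor, loads it, unpacks the fields, and forwards them to the unpacked-convention function. The C++ sketch below shows the same shape by hand for a rank-1 `f32` memref. The function `sum` and the struct `MemRef1DF32` are hypothetical stand-ins, not generated code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical C-side layout of a rank-1 f32 memref descriptor.
struct MemRef1DF32 {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];
};

// Stands in for a lowered function using the unpacked convention:
// one scalar argument per descriptor field.
float sum(float *allocated, float *aligned, int64_t offset, int64_t size,
          int64_t stride) {
  (void)allocated; // only the aligned pointer is used for indexing
  float acc = 0.0f;
  for (int64_t i = 0; i < size; ++i)
    acc += aligned[offset + i * stride];
  return acc;
}

// Counterpart of the emitted wrapper: load the descriptor through the
// pointer, unpack it, and forward the fields to `sum`.
extern "C" float _mlir_ciface_sum(MemRef1DF32 *d) {
  MemRef1DF32 loaded = *d;
  return sum(loaded.allocated, loaded.aligned, loaded.offset, loaded.sizes[0],
             loaded.strides[0]);
}
```

A C caller only needs the struct definition and the `_mlir_ciface_sum` prototype.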
+
+/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
+/// arguments instead of unpacked arguments. Creates a body for the (external)
+/// `newFuncOp` that allocates a memref descriptor on stack, packs the
+/// individual arguments into this descriptor and passes a pointer to it into
+/// the auxiliary function. This auxiliary external function is now compatible
+/// with functions defined in C using pointers to C structs corresponding to a
+/// memref descriptor.
+static void wrapExternalFunction(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
+ OpBuilder::InsertionGuard guard(builder);
+
+ LLVM::LLVMType wrapperType =
+ typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
+ // This conversion can only fail if it could not convert one of the argument
+ // types. But since it has already been applied to the non-wrapper function,
+ // it would have failed earlier and never reached this point.
+ assert(wrapperType && "unexpected type conversion failure");
+
+ SmallVector<NamedAttribute, 4> attributes;
+ filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
+
+ // Create the auxiliary function.
+ auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
+ loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+ wrapperType, LLVM::Linkage::External, attributes);
+
+ builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
+
+ // Get a ValueRange containing the arguments of `newFuncOp`. Note that ValueRange is
+ // currently not constructible from a pair of iterators pointing to
+ // BlockArgument.
+ FunctionType type = funcOp.getType();
+ SmallVector<Value, 8> args;
+ args.reserve(type.getNumInputs());
+ auto wrapperArgIters = newFuncOp.getArguments();
+ SmallVector<Value, 8> wrapperArgs(wrapperArgIters.begin(),
+ wrapperArgIters.end());
+ ValueRange wrapperArgsRange(wrapperArgs);
+
+ // Iterate over the inputs of the original function and pack values into
+ // memref descriptors if the original type is a memref.
+ for (auto &en : llvm::enumerate(type.getInputs())) {
+ Value arg;
+ int numToDrop = 1;
+ auto memRefType = en.value().dyn_cast<MemRefType>();
+ auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
+ if (memRefType || unrankedMemRefType) {
+ numToDrop = memRefType
+ ? MemRefDescriptor::getNumUnpackedValues(memRefType)
+ : UnrankedMemRefDescriptor::getNumUnpackedValues();
+ Value packed =
+ memRefType
+ ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
+ wrapperArgsRange.take_front(numToDrop))
+ : UnrankedMemRefDescriptor::pack(
+ builder, loc, typeConverter, unrankedMemRefType,
+ wrapperArgsRange.take_front(numToDrop));
+
+ auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
+ Value one = builder.create<LLVM::ConstantOp>(
+ loc, typeConverter.convertType(builder.getIndexType()),
+ builder.getIntegerAttr(builder.getIndexType(), 1));
+ Value allocated =
+ builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
+ builder.create<LLVM::StoreOp>(loc, packed, allocated);
+ arg = allocated;
+ } else {
+ arg = wrapperArgsRange[0];
+ }
+
+ args.push_back(arg);
+ wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
+ }
+ assert(wrapperArgsRange.empty() && "did not map some of the arguments");
+
+ auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
+ builder.create<LLVM::ReturnOp>(loc, call.getResults());
+}
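wrapExternalFunction handles the reverse direction. The MLIR-side declaration receives a body that packs its unpacked arguments into a stack-allocated descriptor and passes a pointer to it to the C implementation. In this C++ sketch of that pattern, `last` and `_mlir_ciface_last` are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical C-side layout of a rank-1 f32 memref descriptor.
struct MemRef1DF32 {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];
};

// Stands in for a function written in C against the pointer-to-descriptor
// interface; returns the last element of the memref.
extern "C" float _mlir_ciface_last(MemRef1DF32 *d) {
  return d->aligned[d->offset + (d->sizes[0] - 1) * d->strides[0]];
}

// Counterpart of the generated body: pack the unpacked arguments into a
// descriptor on the stack and pass its address to the C implementation.
float last(float *allocated, float *aligned, int64_t offset, int64_t size,
           int64_t stride) {
  MemRef1DF32 packed{allocated, aligned, offset, {size}, {stride}};
  return _mlir_ciface_last(&packed);
}
```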
+
struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
protected:
- using LLVMLegalizationPattern::LLVMLegalizationPattern;
+ using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
using UnsignedTypePair = std::pair<unsigned, Type>;
// Gather the positions and types of memref-typed arguments in a given
@@ -579,14 +919,24 @@ struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
auto llvmType = lowering.convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
- // Only retain those attributes that are not constructed by build.
+ // Propagate argument attributes to all converted arguments obtained after
+ // converting a given original argument.
SmallVector<NamedAttribute, 4> attributes;
- for (const auto &attr : funcOp.getAttrs()) {
- if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
- attr.first.is(impl::getTypeAttrName()) ||
- attr.first.is("std.varargs"))
+ filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true,
+ attributes);
+ for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+ auto attr = impl::getArgAttrDict(funcOp, i);
+ if (!attr)
continue;
- attributes.push_back(attr);
+
+ auto mapping = result.getInputMapping(i);
+ assert(mapping.hasValue() && "unexpected deletion of function argument");
+
+ SmallString<8> name;
+ for (size_t j = 0; j < mapping->size; ++j) {
+ impl::getArgAttrName(mapping->inputNo + j, name);
+ attributes.push_back(rewriter.getNamedAttr(name, attr));
+ }
}
// Create an LLVM function, use external linkage by default until MLIR
@@ -607,34 +957,33 @@ struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
- using FuncOpConversionBase::FuncOpConversionBase;
+ FuncOpConversion(LLVM::LLVMDialect &dialect, LLVMTypeConverter &converter,
+ bool emitCWrappers)
+ : FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
- // Store the positions of memref-typed arguments so that we can emit loads
- // from them to follow the calling convention.
- SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
- getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
-
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
-
- // Insert loads from memref descriptor pointers in function bodies.
- if (!newFuncOp.getBody().empty()) {
- Block *firstBlock = &newFuncOp.getBody().front();
- rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
- for (const auto &argInfo : promotedArgsInfo) {
- BlockArgument arg = firstBlock->getArgument(argInfo.first);
- Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
- rewriter.replaceUsesOfBlockArgument(arg, loaded);
- }
+ if (emitWrappers) {
+ if (newFuncOp.isExternal())
+ wrapExternalFunction(rewriter, op->getLoc(), lowering, funcOp,
+ newFuncOp);
+ else
+ wrapForExternalCallers(rewriter, op->getLoc(), lowering, funcOp,
+ newFuncOp);
}
rewriter.eraseOp(op);
return matchSuccess();
}
+
+private:
+ /// If true, also create the adaptor functions having signatures compatible
+ /// with those produced by clang.
+ const bool emitWrappers;
};
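convertFuncOpToLLVMFuncOp copies each original argument's attribute dictionary to every argument that replaced it. Assuming the replacements for original argument i occupy positions inputNo through inputNo + size - 1, as given by its input mapping, the target attribute names can be computed as in this sketch. `InputMapping` here is a hypothetical stand-in for the SignatureConversion mapping, and the `arg<N>` naming is an assumption about impl::getArgAttrName.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in: original argument i is replaced by `size` new
// arguments starting at position `inputNo`.
struct InputMapping {
  unsigned inputNo;
  unsigned size;
};

// Names of the argument attributes ("arg<N>", assumed scheme) that one
// original argument's attribute dictionary is copied to.
std::vector<std::string> argAttrNamesFor(const InputMapping &m) {
  std::vector<std::string> names;
  for (unsigned j = 0; j < m.size; ++j)
    names.push_back("arg" + std::to_string(m.inputNo + j));
  return names;
}
```

For a signature `(i32, memref<?xf32>)`, the scalar maps to `arg0` and the memref descriptor to `arg1` through `arg5`.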
/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
@@ -2273,14 +2622,17 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
}
void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+ bool emitCWrappers) {
+ patterns.insert<FuncOpConversion>(*converter.getDialect(), converter,
+ emitCWrappers);
}
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- bool useAlloca) {
- populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
+ bool useAlloca, bool emitCWrappers) {
+ populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
+ emitCWrappers);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
}
@@ -2346,13 +2698,20 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
for (auto it : llvm::zip(opOperands, operands)) {
auto operand = std::get<0>(it);
auto llvmOperand = std::get<1>(it);
- if (!operand.getType().isa<MemRefType>() &&
- !operand.getType().isa<UnrankedMemRefType>()) {
- promotedOperands.push_back(operand);
+
+ if (operand.getType().isa<UnrankedMemRefType>()) {
+ UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+ promotedOperands);
continue;
}
- promotedOperands.push_back(
- promoteOneMemRefDescriptor(loc, llvmOperand, builder));
+ if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+ promotedOperands);
+ continue;
+ }
+
+ promotedOperands.push_back(operand);
}
return promotedOperands;
}
@@ -2362,11 +2721,21 @@ namespace {
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
/// Creates an LLVM lowering pass.
explicit LLVMLoweringPass(bool useAlloca = false,
- bool useBarePtrCallConv = false)
- : useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv) {}
+ bool useBarePtrCallConv = false,
+ bool emitCWrappers = false)
+ : useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv),
+ emitCWrappers(emitCWrappers) {}
/// Run the dialect converter on the module.
void runOnModule() override {
+ if (useBarePtrCallConv && emitCWrappers) {
+ getModule().emitError()
+ << "incompatible conversion options: bare-pointer calling convention "
+ "and C wrapper emission";
+ signalPassFailure();
+ return;
+ }
+
ModuleOp m = getModule();
LLVM::ensureDistinctSuccessors(m);
@@ -2380,7 +2749,8 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
useAlloca);
else
- populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca);
+ populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca,
+ emitCWrappers);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
@@ -2393,19 +2763,23 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
/// Convert memrefs to bare pointers in function signatures.
bool useBarePtrCallConv;
+
+ /// Emit wrappers for C-compatible pointer-to-struct memref descriptors.
+ bool emitCWrappers;
};
} // end namespace
std::unique_ptr<OpPassBase<ModuleOp>>
-mlir::createLowerToLLVMPass(bool useAlloca) {
- return std::make_unique<LLVMLoweringPass>(useAlloca);
+mlir::createLowerToLLVMPass(bool useAlloca, bool emitCWrappers) {
+ return std::make_unique<LLVMLoweringPass>(useAlloca, emitCWrappers);
}
static PassRegistration<LLVMLoweringPass>
- pass("convert-std-to-llvm",
+ pass(PASS_NAME,
"Convert scalar and vector operations from the "
"Standard to the LLVM dialect",
[] {
return std::make_unique<LLVMLoweringPass>(
- clUseAlloca.getValue(), clUseBarePtrCallConv.getValue());
+ clUseAlloca.getValue(), clUseBarePtrCallConv.getValue(),
+ clEmitCWrappers.getValue());
});
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8396995fcb5a..7bc45cbe7fef 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -90,26 +90,27 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
return launchOp.emitOpError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
+ // TODO(ntv,zinenko,herhut): if the kernel function has been converted to
+ // the LLVM dialect but the caller hasn't (which happens during
+ // separate compilation), do not check type correspondence as it would
+ // require the verifier to be aware of the LLVM type conversion.
+ if (kernelLLVMFunction)
+ return success();
+
unsigned actualNumArguments = launchOp.getNumKernelOperands();
- unsigned expectedNumArguments = kernelLLVMFunction
- ? kernelLLVMFunction.getNumArguments()
- : kernelGPUFunction.getNumArguments();
+ unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
if (expectedNumArguments != actualNumArguments)
return launchOp.emitOpError("got ")
<< actualNumArguments << " kernel operands but expected "
<< expectedNumArguments;
- // Due to the ordering of the current impl of lowering and LLVMLowering,
- // type checks need to be temporarily disabled.
- // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
- // to encode target module" has landed.
- // auto functionType = kernelFunc.getType();
- // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
- // if (getKernelOperand(i).getType() != functionType.getInput(i)) {
- // return emitOpError("type of function argument ")
- // << i << " does not match";
- // }
- // }
+ auto functionType = kernelGPUFunction.getType();
+ for (unsigned i = 0; i < expectedNumArguments; ++i) {
+ if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
+ return launchOp.emitOpError("type of function argument ")
+ << i << " does not match";
+ }
+ }
return success();
});
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index a4dc8977e73b..40dd16db02cf 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -401,8 +401,7 @@ Block *ArgConverter::applySignatureConversion(
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
Operation *cast = typeConverter->materializeConversion(
rewriter, origArg.getType(), replArgs, loc);
- assert(cast->getNumResults() == 1 &&
- cast->getNumOperands() == replArgs.size());
+ assert(cast->getNumResults() == 1);
mapping.map(origArg, cast->getResult(0));
info.argInfo[i] =
ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
diff --git a/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir b/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir
index 707f4a063958..bb02b5d74b53 100644
--- a/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir
+++ b/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir
@@ -6,8 +6,8 @@ module attributes {gpu.container_module} {
// CHECK: llvm.mlir.global internal constant @[[global:.*]]("CUBIN")
gpu.module @kernel_module attributes {nvvm.cubin = "CUBIN"} {
- gpu.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} {
- gpu.return
+ llvm.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} {
+ llvm.return
}
}
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir b/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir
index ca796d103ad6..f2cd0f8e694d 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir
@@ -1,25 +1,11 @@
// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
-
-// CHECK-LABEL: func @check_attributes(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> {dialect.a = true, dialect.b = 4 : i64}) {
-// CHECK-NEXT: llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-LABEL: func @check_attributes
+// When expanding the memref to multiple arguments, argument attributes are replicated.
+// CHECK-COUNT-7: {dialect.a = true, dialect.b = 4 : i64}
func @check_attributes(%static: memref<10x20xf32> {dialect.a = true, dialect.b = 4 : i64 }) {
%c0 = constant 0 : index
%0 = load %static[%c0, %c0]: memref<10x20xf32>
return
}
-// CHECK-LABEL: func @external_func(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
-// CHECK: func @call_external(%[[arg:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
-// CHECK: %[[ld:.*]] = llvm.load %[[arg]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK: %[[alloca:.*]] = llvm.alloca %[[c1]] x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK: llvm.store %[[ld]], %[[alloca]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK: call @external_func(%[[alloca]]) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
-func @external_func(memref<10x20xf32>)
-
-func @call_external(%static: memref<10x20xf32>) {
- call @external_func(%static) : (memref<10x20xf32>) -> ()
- return
-}
-
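The `CHECK-COUNT-7` in the test above follows from the descriptor layout. A ranked memref expands into its allocated and aligned pointers, its offset, and one size and one stride per dimension, and each resulting argument receives a copy of the argument attributes. A small C++ sketch of that count, checked against the signatures appearing in these tests (the helper name `numUnpackedMemRefArgs` is hypothetical):

```cpp
#include <cstdint>

// Number of scalar arguments a ranked memref of the given rank expands to
// under the unpacked convention: two pointers (allocated, aligned), one
// offset, then one size and one stride per dimension.
constexpr int64_t numUnpackedMemRefArgs(int64_t rank) {
  return 2 + 1 + 2 * rank;
}

// memref<10x20xf32> (rank 2): 2 x float* + 5 x i64, the CHECK-COUNT-7 case.
static_assert(numUnpackedMemRefArgs(2) == 7, "rank-2 memref -> 7 arguments");
// memref<?xf32> (rank 1): (float*, float*, i64, i64, i64).
static_assert(numUnpackedMemRefArgs(1) == 5, "rank-1 memref -> 5 arguments");
// memref<f32> (rank 0): { float*, float*, i64 } has no size/stride arrays.
static_assert(numUnpackedMemRefArgs(0) == 3, "rank-0 memref -> 3 arguments");
```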
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 43cbc78bc3b1..0c8e007a6adf 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,14 +1,25 @@
// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
-// CHECK-LABEL: func @check_strided_memref_arguments(
-// CHECK-COUNT-3: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-LABEL: func @check_strided_memref_arguments(
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
func @check_strided_memref_arguments(%static: memref<10x20xf32, affine_map<(i,j)->(20 * i + j + 1)>>,
%dynamic : memref<?x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>,
%mixed : memref<10x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>) {
return
}
-// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
+// CHECK-LABEL: func @check_arguments
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
return
}
@@ -16,7 +27,7 @@ func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %m
// CHECK-LABEL: func @mixed_alloc(
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> {
func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
-// CHECK-NEXT: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: llvm.mul %[[M]], %[[c42]] : !llvm.i64
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
@@ -45,10 +56,9 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
return %0 : memref<?x42x?xf32>
}
-// CHECK-LABEL: func @mixed_dealloc(%arg0: !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) {
+// CHECK-LABEL: func @mixed_dealloc
func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
dealloc %arg0 : memref<?x42x?xf32>
@@ -59,7 +69,7 @@ func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-LABEL: func @dynamic_alloc(
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
-// CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
+// CHECK: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
@@ -83,10 +93,9 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
return %0 : memref<?x?xf32>
}
-// CHECK-LABEL: func @dynamic_dealloc(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @dynamic_dealloc
func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
dealloc %arg0 : memref<?x?xf32>
@@ -94,10 +103,12 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
}
// CHECK-LABEL: func @mixed_load(
-// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
+// CHECK: %[[I:.*]]: !llvm.i64,
+// CHECK: %[[J:.*]]: !llvm.i64)
func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -112,10 +123,8 @@ func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
}
// CHECK-LABEL: func @dynamic_load(
-// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -131,8 +140,7 @@ func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-LABEL: func @prefetch
func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -161,8 +169,7 @@ func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-LABEL: func @dynamic_store
func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -171,15 +178,14 @@ func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
store %val, %dynamic[%i, %j] : memref<?x?xf32>
return
}
// CHECK-LABEL: func @mixed_store
func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -188,74 +194,66 @@ func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32)
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
store %val, %mixed[%i, %j] : memref<42x?xf32>
return
}
// CHECK-LABEL: func @memref_cast_static_to_dynamic
func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
return
}
// CHECK-LABEL: func @memref_cast_static_to_mixed
func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %static : memref<10x42xf32> to memref<?x42xf32>
return
}
// CHECK-LABEL: func @memref_cast_dynamic_to_static
func @memref_cast_dynamic_to_static(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
return
}
// CHECK-LABEL: func @memref_cast_dynamic_to_mixed
func @memref_cast_dynamic_to_mixed(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %dynamic : memref<?x?xf32> to memref<?x12xf32>
return
}
// CHECK-LABEL: func @memref_cast_mixed_to_dynamic
func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
return
}
// CHECK-LABEL: func @memref_cast_mixed_to_static
func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
return
}
// CHECK-LABEL: func @memref_cast_mixed_to_mixed
func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
return
}
// CHECK-LABEL: func @memref_cast_ranked_to_unranked
func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG: %[[p:.*]] = llvm.alloca %[[c]] x !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-DAG: llvm.store %[[ld]], %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %2 : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
+// CHECK-DAG: llvm.store %{{.*}}, %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
+// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
// CHECK-DAG: %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK : llvm.mlir.undef : !llvm<"{ i64, i8* }">
// CHECK-DAG: llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm<"{ i64, i8* }">
@@ -266,19 +264,17 @@ func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
// CHECK-LABEL: func @memref_cast_unranked_to_ranked
func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
-// CHECK: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*">
-// CHECK-NEXT: %[[p:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ i64, i8* }">
+// CHECK: %[[p:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i8* }">
// CHECK-NEXT: llvm.bitcast %[[p]] : !llvm<"i8*"> to !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }*">
%0 = memref_cast %arg : memref<*xf32> to memref<?x?x10x2xf32>
return
}
-// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
+// CHECK-LABEL: func @mixed_memref_dim
func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
%0 = dim %mixed, 0 : memref<42x?x?x13x?xf32>
-// CHECK-NEXT: llvm.extractvalue %[[ld]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK-NEXT: llvm.extractvalue %[[ld:.*]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
%1 = dim %mixed, 1 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: llvm.extractvalue %[[ld]][3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
%2 = dim %mixed, 2 : memref<42x?x?x13x?xf32>
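The load/store CHECK sequences in this file read the aligned pointer (field 1) and the strides (field 4) of the descriptor, while `dealloc` passes the allocated pointer (field 0) to `free`. The C++ sketch below reproduces the address arithmetic on a plain struct. The struct and function names are hypothetical, and the sketch reads the offset from the descriptor, whereas the lowering checked here emits a constant `0` for the identity layout (and constant strides such as `42` and `1` for static shapes).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of { float*, float*, i64, [2 x i64], [2 x i64] }.
struct MemRef2D {
  float *allocated;   // field 0: passed to free by the dealloc lowering
  float *aligned;     // field 1: base pointer for loads and stores
  int64_t offset;     // field 2
  int64_t sizes[2];   // field 3
  int64_t strides[2]; // field 4
};

// Mirrors the checked sequence for `load %m[%i, %j]`:
//   off0 = off + i * st0; off1 = off0 + j * st1; addr = gep aligned[off1]
int64_t linearIndex(const MemRef2D &m, int64_t i, int64_t j) {
  int64_t off0 = m.offset + i * m.strides[0];
  return off0 + j * m.strides[1];
}

float loadElement(const MemRef2D &m, int64_t i, int64_t j) {
  // Debug-only bounds check added for the sketch; the lowered code has none.
  assert(i >= 0 && i < m.sizes[0] && j >= 0 && j < m.sizes[1]);
  return m.aligned[linearIndex(m, i, j)];
}
```

For a `memref<10x42xf32>` with the default layout (strides `42` and `1`, offset `0`), element `[2, 3]` lands at linear index `2 * 42 + 3 = 87`.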
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
index 9984f8c3766b..72d82380d979 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
@@ -18,12 +18,12 @@ func @fifth_order_left(%arg0: (((() -> ()) -> ()) -> ()) -> ())
//CHECK: llvm.func @fifth_order_right(!llvm<"void ()* ()* ()* ()*">)
func @fifth_order_right(%arg0: () -> (() -> (() -> (() -> ()))))
-// Check that memrefs are converted to pointers-to-struct if appear as function arguments.
-// CHECK: llvm.func @memref_call_conv(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)
+// Check that memrefs are converted to argument packs if they appear as function arguments.
+// CHECK: llvm.func @memref_call_conv(!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
func @memref_call_conv(%arg0: memref<?xf32>)
// Same in nested functions.
-// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void ({ float*, float*, i64, [1 x i64], [1 x i64] }*)*">)
+// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void (float*, float*, i64, i64, i64)*">)
func @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())
//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm<"void ()*">) -> !llvm<"void ()*"> {
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index e44d2fcf99bc..c25d8a235701 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -10,7 +10,10 @@ func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) {
// -----
-// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// CHECK-LABEL: func @check_static_return
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-LABEL: func @check_static_return
// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
@@ -76,11 +79,10 @@ func @zero_d_alloc() -> memref<f32> {
// -----
-// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
+// CHECK-LABEL: func @zero_d_dealloc
// BAREPTR-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"float*">) {
func @zero_d_dealloc(%arg0: memref<f32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
@@ -96,7 +98,7 @@ func @zero_d_dealloc(%arg0: memref<f32>) {
// CHECK-LABEL: func @aligned_1d_alloc(
// BAREPTR-LABEL: func @aligned_1d_alloc(
func @aligned_1d_alloc() -> memref<42xf32> {
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
@@ -150,13 +152,13 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
// BAREPTR-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @static_alloc() -> memref<32x18xf32> {
-// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// CHECK: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
@@ -177,11 +179,10 @@ func @static_alloc() -> memref<32x18xf32> {
// -----
-// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @static_dealloc
// BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm<"float*">) {
func @static_dealloc(%static: memref<10x8xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
@@ -194,11 +195,10 @@ func @static_dealloc(%static: memref<10x8xf32>) {
// -----
-// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float {
+// CHECK-LABEL: func @zero_d_load
// BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm<"float*">) -> !llvm.float
func @zero_d_load(%arg0: memref<f32>) -> f32 {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*">
@@ -214,20 +214,22 @@ func @zero_d_load(%arg0: memref<f32>) -> f32 {
// -----
// CHECK-LABEL: func @static_load(
-// CHECK-SAME: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
+// CHECK: %[[I:.*]]: !llvm.i64,
+// CHECK: %[[J:.*]]: !llvm.i64)
// BAREPTR-LABEL: func @static_load
// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) {
func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*">
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -246,15 +248,14 @@ func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// -----
-// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) {
+// CHECK-LABEL: func @zero_d_store
// BAREPTR-LABEL: func @zero_d_store
// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[val:.*]]: !llvm.float)
func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg1, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@@ -270,17 +271,16 @@ func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
// BAREPTR-LABEL: func @static_store
// BAREPTR-SAME: %[[A:.*]]: !llvm<"float*">
func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@@ -298,11 +298,10 @@ func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f
// -----
-// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
+// CHECK-LABEL: func @static_memref_dim
// BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
// BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
%0 = dim %static, 0 : memref<42x32x15x13x27xf32>
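For reference, the CHECK lines of `@static_store` above expect the lowered code to take the aligned data pointer (descriptor field 1) and index it with `offset + i * 42 + j * 1`. Here 0, 42 and 1 are the static offset and strides of `memref<10x42xf32>`, folded into constants. The C++ sketch below reproduces that arithmetic on a plain buffer. `storeStatic` is a hypothetical name used only for illustration. It is not part of MLIR or of this patch.

```cpp
#include <cstdint>

// Hypothetical helper: mirrors the address arithmetic that the CHECK lines
// of @static_store expect for memref<10x42xf32> (offset 0, strides [42, 1]).
void storeStatic(float *aligned, int64_t i, int64_t j, float val) {
  const int64_t offset = 0;   // llvm.mlir.constant(0 : index)
  const int64_t stride0 = 42; // llvm.mlir.constant(42 : index)
  const int64_t stride1 = 1;  // llvm.mlir.constant(1 : index)
  // The llvm.mul / llvm.add chain: offset + i * stride0 + j * stride1.
  const int64_t linear = offset + i * stride0 + j * stride1;
  // llvm.getelementptr on the aligned pointer, then llvm.store.
  aligned[linear] = val;
}
```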
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 66a99ae5d86b..302aa31e48e0 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -728,9 +728,15 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
}
// CHECK-LABEL: func @subview(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64,
+// CHECK: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i64,
+// CHECK: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i64,
+// CHECK: %[[ARG2:.*]]: !llvm.i64)
func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+ // The last "insertvalue" that populates the memref descriptor from the function arguments.
+ // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -754,9 +760,10 @@ func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg
}
// CHECK-LABEL: func @subview_const_size(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+ // The last "insertvalue" that populates the memref descriptor from the function arguments.
+ // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -782,9 +789,10 @@ func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 +
}
// CHECK-LABEL: func @subview_const_stride(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+ // The last "insertvalue" that populates the memref descriptor from the function arguments.
+ // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index e020c8f6931c..fd574d7e67d9 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -1,8 +1,7 @@
// RUN: mlir-opt %s -convert-std-to-llvm -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @address_space(
-// CHECK: %{{.*}}: !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">)
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">
+// CHECK-SAME: !llvm<"float addrspace(7)*">
func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
%0 = alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5>
%1 = constant 7 : index
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 415990a47b68..341af2dabf68 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -175,23 +175,21 @@ module attributes {gpu.container_module} {
// -----
-gpu.module @kernels {
- gpu.func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
- gpu.return
+module attributes {gpu.container_module} {
+ gpu.module @kernels {
+ gpu.func @kernel_1(%arg1 : f32) attributes { gpu.kernel } {
+ gpu.return
+ }
}
-}
-// Due to the ordering of the current impl of lowering and LLVMLowering, type
-// checks need to be temporarily disabled.
-// TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
-// to encode target module" has landed.
-// func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
-// // expected-err@+1 {{type of function argument 0 does not match}}
-// "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
-// {kernel = "kernel_1"}
-// : (index, index, index, index, index, index, f32) -> ()
-// return
-// }
+ func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
+ // expected-err@+1 {{type of function argument 0 does not match}}
+ "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
+ {kernel = "kernel_1", kernel_module = @kernels}
+ : (index, index, index, index, index, index, f32) -> ()
+ return
+ }
+}
// -----
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 4a5f15b319b4..82ec950584d0 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -52,9 +52,11 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
return
}
-// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
-// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
-// CHECK-NEXT: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64 }*">) -> ()
+// CHECK-LABEL: func @dot
+// CHECK: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}) :
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64
func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?, 1]>) {
%c0 = constant 0 : index
@@ -83,7 +85,9 @@ func @copy(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memre
return
}
// CHECK-LABEL: func @copy
-// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32({{.*}}) :
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
%0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
@@ -128,9 +132,8 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// Call external copy after promoting input and output structs to pointers
-// CHECK-COUNT-2: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
-// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// Call external copy.
+// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32
#matmul_accesses = [
affine_map<(m, n, k) -> (m, k)>,
@@ -163,7 +166,10 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
return
}
// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) :
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
@@ -195,7 +201,10 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
}
// CHECK-LABEL: func @matmul_vec_indexed(
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}) :
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
func @reshape_static(%arg0: memref<3x4x5xf32>) {
// Reshapes that expand and collapse back a contiguous tensor with some 1's.
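For reference, the CHECK-SAME lines in these Linalg tests follow from the unpacked convention. A rank-N memref becomes two pointers (allocated and aligned) plus 1 + 2N `i64` values (offset, N sizes, N strides). The `linalg.dot` call therefore passes 5 + 5 + 3 scalars, and each rank-3 view in `@copy` contributes two pointers and seven `i64` values. The helpers below are a hypothetical compile-time check of that count. They are not MLIR code.

```cpp
#include <cstdint>

// Hypothetical helpers (not MLIR API): number of scalar arguments a rank-N
// memref expands into under the unpacked convention, namely the allocated
// pointer, the aligned pointer, the offset, N sizes and N strides.
constexpr int64_t unpackedArgCount(int64_t rank) { return 3 + 2 * rank; }

// All of them except the two pointers are !llvm.i64 values.
constexpr int64_t unpackedIndexArgCount(int64_t rank) {
  return unpackedArgCount(rank) - 2;
}

// linalg_dot_viewsxf32_viewsxf32_viewf32: two rank-1 views and one rank-0 view.
static_assert(2 * unpackedArgCount(1) + unpackedArgCount(0) == 13,
              "dot takes 5 + 5 + 3 scalar arguments");
// Rank-2 views (@matmul_vec_impl): two pointers and five i64 values each.
static_assert(unpackedIndexArgCount(2) == 5, "rank-2 memref: five i64 values");
// Rank-3 views (@copy): two pointers and seven i64 values each.
static_assert(unpackedIndexArgCount(3) == 7, "rank-3 memref: seven i64 values");
```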
diff --git a/mlir/test/mlir-cpu-runner/cblas_interface.cpp b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
index b6d212070e2b..b1fe5a6a1470 100644
--- a/mlir/test/mlir-cpu-runner/cblas_interface.cpp
+++ b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
@@ -15,32 +15,35 @@
#include <assert.h>
#include <iostream>
-extern "C" void linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X,
- float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
X->data[X->offset] = f;
}
-extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
- float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
+ float f) {
for (unsigned i = 0; i < X->sizes[0]; ++i)
*(X->data + X->offset + i * X->strides[0]) = f;
}
-extern "C" void linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
- float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
+ float f) {
for (unsigned i = 0; i < X->sizes[0]; ++i)
for (unsigned j = 0; j < X->sizes[1]; ++j)
*(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
}
-extern "C" void linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
- StridedMemRefType<float, 0> *O) {
+extern "C" void
+_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
+ StridedMemRefType<float, 0> *O) {
O->data[O->offset] = I->data[I->offset];
}
extern "C" void
-linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
- StridedMemRefType<float, 1> *O) {
+_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
+ StridedMemRefType<float, 1> *O) {
if (I->sizes[0] != O->sizes[0]) {
std::cerr << "Incompatible strided memrefs\n";
printMemRefMetaData(std::cerr, *I);
@@ -52,9 +55,8 @@ linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
I->data[I->offset + i * I->strides[0]];
}
-extern "C" void
-linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
- StridedMemRefType<float, 2> *O) {
+extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
+ StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
std::cerr << "Incompatible strided memrefs\n";
printMemRefMetaData(std::cerr, *I);
@@ -69,10 +71,9 @@ linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
I->data[I->offset + i * si0 + j * si1];
}
-extern "C" void
-linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
- StridedMemRefType<float, 1> *Y,
- StridedMemRefType<float, 0> *Z) {
+extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
+ StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
+ StridedMemRefType<float, 0> *Z) {
if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
std::cerr << "Incompatible strided memrefs\n";
printMemRefMetaData(std::cerr, *X);
@@ -85,7 +86,7 @@ linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
Y->data + Y->offset, Y->strides[0]);
}
-extern "C" void linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
+extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
StridedMemRefType<float, 2> *C) {
if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
diff --git a/mlir/test/mlir-cpu-runner/include/cblas_interface.h b/mlir/test/mlir-cpu-runner/include/cblas_interface.h
index c04861cf047f..83292208aa61 100644
--- a/mlir/test/mlir-cpu-runner/include/cblas_interface.h
+++ b/mlir/test/mlir-cpu-runner/include/cblas_interface.h
@@ -25,33 +25,34 @@
#endif // _WIN32
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
+_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
+_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X, float f);
+_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
+ float f);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
- StridedMemRefType<float, 0> *O);
+_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
+ StridedMemRefType<float, 0> *O);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
- StridedMemRefType<float, 1> *O);
+_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
+ StridedMemRefType<float, 1> *O);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
- StridedMemRefType<float, 2> *O);
+_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
+ StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
- StridedMemRefType<float, 1> *Y,
- StridedMemRefType<float, 0> *Z);
+_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
+ StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
+ StridedMemRefType<float, 0> *Z);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
+_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
StridedMemRefType<float, 2> *C);
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
index b637decf7767..a10317a49399 100644
--- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
@@ -261,23 +261,27 @@ template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
// Currently exposed C API.
////////////////////////////////////////////////////////////////////////////////
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_i8(UnrankedMemRefType<int8_t> *M);
+_mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_f32(UnrankedMemRefType<float> *M);
+_mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
+
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_f32(int64_t rank,
+ void *ptr);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_0d_f32(StridedMemRefType<float, 0> *M);
+_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_1d_f32(StridedMemRefType<float, 1> *M);
+_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_2d_f32(StridedMemRefType<float, 2> *M);
+_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_3d_f32(StridedMemRefType<float, 3> *M);
+_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_4d_f32(StridedMemRefType<float, 4> *M);
+_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
+_mlir_ciface_print_memref_vector_4x4xf32(
+ StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
// Small runtime support "lib" for vector.print lowering.
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);
diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
index 9225eab00a06..984d29dfd2ca 100644
--- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
+++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
@@ -16,8 +16,8 @@
#include <cinttypes>
#include <cstdio>
-extern "C" void
-print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
+extern "C" void _mlir_ciface_print_memref_vector_4x4xf32(
+ StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
impl::printMemRef(*M);
}
@@ -26,7 +26,7 @@ print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
impl::printMemRef(*(static_cast<StridedMemRefType<TYPE, RANK> *>(ptr))); \
break
-extern "C" void print_memref_i8(UnrankedMemRefType<int8_t> *M) {
+extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
printUnrankedMemRefMetaData(std::cout, *M);
int rank = M->rank;
void *ptr = M->descriptor;
@@ -42,7 +42,7 @@ extern "C" void print_memref_i8(UnrankedMemRefType<int8_t> *M) {
}
}
-extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
+extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
printUnrankedMemRefMetaData(std::cout, *M);
int rank = M->rank;
void *ptr = M->descriptor;
@@ -58,19 +58,31 @@ extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
}
}
-extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
+extern "C" void print_memref_f32(int64_t rank, void *ptr) {
+ UnrankedMemRefType<float> descriptor;
+ descriptor.rank = rank;
+ descriptor.descriptor = ptr;
+ _mlir_ciface_print_memref_f32(&descriptor);
+}
+
+extern "C" void
+_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
impl::printMemRef(*M);
}
-extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
+extern "C" void
+_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
impl::printMemRef(*M);
}
-extern "C" void print_memref_2d_f32(StridedMemRefType<float, 2> *M) {
+extern "C" void
+_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M) {
impl::printMemRef(*M);
}
-extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
+extern "C" void
+_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
impl::printMemRef(*M);
}
-extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
+extern "C" void
+_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
impl::printMemRef(*M);
}
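An unranked `memref<*xf32>` is passed as two scalars: the rank and a type-erased pointer to a ranked descriptor. This is why the new `print_memref_f32(int64_t rank, void *ptr)` only has to rebuild an `UnrankedMemRefType<float>` before delegating. The sketch below applies the same dispatch-on-rank idea, counting elements instead of printing them. `Ranked`, `Unranked`, `product` and `numElements` are hypothetical names. Only ranks 1 and 2 are handled.

```cpp
#include <cstdint>

// Hypothetical mirror of StridedMemRefType<T, N>.
template <typename T, int N> struct Ranked {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

// Hypothetical mirror of UnrankedMemRefType<T>: a rank plus an opaque
// pointer to a ranked descriptor of that rank.
template <typename T> struct Unranked {
  int64_t rank;
  void *descriptor;
};

// Product of the sizes of a ranked descriptor.
template <int N> int64_t product(const Ranked<float, N> *m) {
  int64_t count = 1;
  for (int d = 0; d < N; ++d)
    count *= m->sizes[d];
  return count;
}

// Entry point taking the two unpacked scalars, shaped like print_memref_f32.
// Returns -1 for ranks this sketch does not handle.
int64_t numElements(int64_t rank, void *ptr) {
  Unranked<float> desc{rank, ptr};
  switch (desc.rank) {
  case 1:
    return product<1>(static_cast<Ranked<float, 1> *>(desc.descriptor));
  case 2:
    return product<2>(static_cast<Ranked<float, 2> *>(desc.descriptor));
  default:
    return -1;
  }
}
```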
diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
index c6010eb099d8..bf882761f148 100644
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -17,12 +17,13 @@ func @main() {
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref<?xf32>) -> ()
- call @print_memref_1d_f32(%22) : (memref<?xf32>) -> ()
+ %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+ call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
- call @print_memref_1d_f32(%22) : (memref<?xf32>) -> ()
+ call @print_memref_f32(%23) : (memref<*xf32>) -> ()
return
}
func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
-func @print_memref_1d_f32(memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index a4076da692c3..350d9869373a 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -96,11 +96,34 @@ void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
std::fill_n(arg->data, count, value);
mcuMemHostRegister(arg->data, count * sizeof(T));
}
-extern "C" void
-mcuMemHostRegisterMemRef1dFloat(const MemRefType<float, 1> *arg) {
- mcuMemHostRegisterMemRef(arg, 1.23f);
+
+extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
+ float *aligned, int64_t offset,
+ int64_t size, int64_t stride) {
+ MemRefType<float, 1> descriptor;
+ descriptor.basePtr = allocated;
+ descriptor.data = aligned;
+ descriptor.offset = offset;
+ descriptor.sizes[0] = size;
+ descriptor.strides[0] = stride;
+ mcuMemHostRegisterMemRef(&descriptor, 1.23f);
}
-extern "C" void
-mcuMemHostRegisterMemRef3dFloat(const MemRefType<float, 3> *arg) {
- mcuMemHostRegisterMemRef(arg, 1.23f);
+
+extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
+ float *aligned, int64_t offset,
+ int64_t size0, int64_t size1,
+ int64_t size2, int64_t stride0,
+ int64_t stride1,
+ int64_t stride2) {
+ MemRefType<float, 3> descriptor;
+ descriptor.basePtr = allocated;
+ descriptor.data = aligned;
+ descriptor.offset = offset;
+ descriptor.sizes[0] = size0;
+ descriptor.strides[0] = stride0;
+ descriptor.sizes[1] = size1;
+ descriptor.strides[1] = stride1;
+ descriptor.sizes[2] = size2;
+ descriptor.strides[2] = stride2;
+ mcuMemHostRegisterMemRef(&descriptor, 1.23f);
}
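The rewritten `mcuMemHostRegisterMemRef3dFloat` receives its sizes and strides in descriptor order: all N sizes first, then all N strides. This mirrors the two `[N x i64]` arrays of `{ T*, T*, i64, [N x i64], [N x i64] }`. A possible generic version of the same packing is sketched below. `MemRefDesc` and `pack` are hypothetical helpers that do not exist in `cuda-runtime-wrappers.cpp`. `tail` is assumed to hold the 2N trailing arguments in the order they are passed.

```cpp
#include <cstdint>

// Hypothetical mirror of MemRefType<T, N> / StridedMemRefType<T, N>.
template <typename T, int N> struct MemRefDesc {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

// Packs an unpacked argument list into a descriptor. `tail` holds the
// N sizes followed by the N strides (size0..sizeN-1, stride0..strideN-1).
// N cannot be deduced from 2 * N, so callers spell it out: pack<float, 3>.
template <typename T, int N>
MemRefDesc<T, N> pack(T *allocated, T *aligned, int64_t offset,
                      const int64_t (&tail)[2 * N]) {
  MemRefDesc<T, N> d;
  d.basePtr = allocated;
  d.data = aligned;
  d.offset = offset;
  for (int i = 0; i < N; ++i) {
    d.sizes[i] = tail[i];
    d.strides[i] = tail[N + i];
  }
  return d;
}
```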