[Mlir-commits] [mlir] 7ec88f0 - [mlir][memref][transform] Add vector_to_llvm conversion patterns
Matthias Springer
llvmlistbot at llvm.org
Wed Aug 9 02:28:05 PDT 2023
Author: Matthias Springer
Date: 2023-08-09T11:27:53+02:00
New Revision: 7ec88f06d5833dfb4c7029c7645ae6cb89520504
URL: https://github.com/llvm/llvm-project/commit/7ec88f06d5833dfb4c7029c7645ae6cb89520504
DIFF: https://github.com/llvm/llvm-project/commit/7ec88f06d5833dfb4c7029c7645ae6cb89520504.diff
LOG: [mlir][memref][transform] Add vector_to_llvm conversion patterns
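These patterns are exposed via a new `apply_conversion_patterns.vector.vector_to_llvm` pattern descriptor op, which is nested inside the existing "apply_conversion_patterns" op.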
Also provide a new type converter op that converts memref types to LLVM types. Conversion patterns that lower to LLVM are special: they require an `LLVMTypeConverter`; a plain `TypeConverter` is not enough. This revision also adds a new interface method to conversion pattern descriptor ops to verify that the default type converter of the enclosing "apply_conversion_patterns" op is compatible with the set of patterns. At the moment, the kind of type converter is identified by a simple `StringRef` (e.g., "LLVMTypeConverter"). This can evolve into a richer type in the future if needed.
Differential Revision: https://reviews.llvm.org/D157369
Added:
mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 86a3586bcc58ba..243ce16b019dee 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -15,6 +15,33 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
+ "apply_conversion_patterns.memref.memref_to_llvm_type_converter",
+ [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+ ["getTypeConverterType"]>]> {
+ let description = [{
+ This operation provides an "LLVMTypeConverter" that lowers memref types
+ (and other builtin types handled by that converter, such as vectors) to
+ LLVM types. Conversion patterns that lower to LLVM, e.g.
+ `apply_conversion_patterns.vector.vector_to_llvm`, require this converter.
+
+ The type converter can be customized as follows:
+ - `use_aligned_alloc`: Use aligned_alloc in place of malloc for heap
+ allocations.
+ - `index_bitwidth`: Bitwidth of the index type. "0" (the default) means
+ that the bitwidth is derived from the data layout.
+ - `use_generic_functions`: Use generic allocation and deallocation functions
+ instead of the classic "malloc", "aligned_alloc" and "free" functions.
+ - `use_opaque_pointers`: Generate LLVM IR using opaque pointers instead of
+ typed pointers.
+ }];
+
+ let arguments = (ins
+ DefaultValuedAttr<BoolAttr, "false">:$use_aligned_alloc,
+ DefaultValuedAttr<I64Attr, "0">:$index_bitwidth,
+ DefaultValuedAttr<BoolAttr, "false">:$use_generic_functions,
+ DefaultValuedAttr<BoolAttr, "false">:$use_opaque_pointers);
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyExpandOpsPatternsOp : Op<Transform_Dialect,
"apply_patterns.memref.expand_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index 5ce21be223bcb2..d40d780e73c554 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -263,6 +263,39 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
];
}
+def TypeConverterBuilderOpInterface
+ : OpInterface<"TypeConverterBuilderOpInterface"> {
+ let description = [{
+ This interface should be implemented by ops that specify a type converter
+ for a dialect conversion. Such ops can be used as the default type
+ converter of "apply_conversion_patterns".
+ }];
+
+ let cppNamespace = "::mlir::transform";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the type converter to be used with a dialect conversion.
+ }],
+ /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
+ /*name=*/"getTypeConverter",
+ /*arguments=*/(ins)
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/[{
+ Return the kind of type converter that `getTypeConverter` returns, e.g.
+ "TypeConverter" (the default) or "LLVMTypeConverter". Conversion pattern
+ descriptor ops compare this string in `verifyTypeConverter`, so it must
+ name the concrete converter class. This is used for op verification.
+ }],
+ /*returnType=*/"::llvm::StringRef",
+ /*name=*/"getTypeConverterType",
+ /*arguments=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return "TypeConverter"; }]
+ >,
+ ];
+}
+
def ConversionPatternDescriptorOpInterface
: OpInterface<"ConversionPatternDescriptorOpInterface"> {
let description = [{
@@ -300,27 +333,16 @@ def ConversionPatternDescriptorOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/"return nullptr;"
>,
- ];
-}
-
-def TypeConverterBuilderOpInterface
- : OpInterface<"TypeConverterBuilderOpInterface"> {
- let description = [{
- This interface should be implemented by ops that specify a type converter
- for a dialect conversion. Such ops can be used with
- "apply_conversion_patterns".
- }];
-
- let cppNamespace = "::mlir::transform";
-
- let methods = [
InterfaceMethod<
/*desc=*/[{
- Return the type converter to be used with a dialect conversion.
+ Verify the default type converter that is provided by the enclosing
+ "apply_conversion_patterns" op.
}],
- /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
- /*name=*/"getTypeConverter",
- /*arguments=*/(ins)
+ /*returnType=*/"::mlir::LogicalResult",
+ /*name=*/"verifyTypeConverter",
+ /*arguments=*/(ins "TypeConverterBuilderOpInterface":$builder),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return success();"
>,
];
}
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index e3d27cb4c71690..2b8c95a94257e6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -15,6 +15,28 @@ include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
+ "apply_conversion_patterns.vector.vector_to_llvm",
+ [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
+ ["verifyTypeConverter"]>]> {
+ let arguments = (ins
+ DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
+ DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+ let assemblyFormat = "attr-dict";
+
+ let description = [{
+ Populates the pattern set with patterns that lower vector dialect ops to
+ LLVM dialect ops. The patterns need an "LLVMTypeConverter"; the default
+ type converter of the enclosing op is checked for this during verification.
+
+ Options:
+ - `reassociate_fp_reductions` (default: false): lets LLVM reassociate
+ floating-point reductions for speed.
+ - `force_32bit_vector_indices` (default: true): lets the compiler assume
+ that vector indices fit in 32 bits when that yields faster code.
+ }];
+}
+
+
def ApplyCastAwayVectorLeadingOneDimPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.cast_away_vector_leading_one_dim",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 671a4bd868ff81..559c0c00013d45 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -37,6 +37,8 @@ class Value;
/// registered using addConversion and addMaterialization, respectively.
class TypeConverter {
public:
+ virtual ~TypeConverter() = default;
+
/// This class provides all of the information necessary to convert a type
/// signature.
class SignatureConversion {
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
index 0c9ee3d00afe69..e379663f3b0a01 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -11,6 +11,8 @@ add_mlir_dialect_library(MLIRMemRefTransformOps
MLIRAffineDialect
MLIRArithDialect
MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
MLIRLoopLikeInterface
MLIRMemRefDialect
MLIRMemRefTransforms
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index a46f2a37cb1d28..af067c0453333a 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -26,6 +28,29 @@ using namespace mlir;
#define DEBUG_TYPE "memref-transforms"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+//===----------------------------------------------------------------------===//
+// Apply...ConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+/// Builds an LLVMTypeConverter whose lowering options mirror the attributes
+/// of this op.
+std::unique_ptr<TypeConverter>
+transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
+ using AllocLowering = LowerToLLVMOptions::AllocLowering;
+
+ LowerToLLVMOptions options(getContext());
+ options.useOpaquePointers = getUseOpaquePointers();
+ options.useGenericFunctions = getUseGenericFunctions();
+ options.allocLowering =
+ getUseAlignedAlloc() ? AllocLowering::AlignedAlloc : AllocLowering::Malloc;
+
+ // Only override the index bitwidth when a non-default value was requested;
+ // otherwise keep whatever LowerToLLVMOptions derives from the data layout.
+ auto requestedBitwidth = getIndexBitwidth();
+ if (requestedBitwidth != kDeriveIndexBitwidthFromDataLayout)
+ options.overrideIndexBitwidth(requestedBitwidth);
+
+ return std::make_unique<LLVMTypeConverter>(getContext(), options);
+}
+
+/// Identifies the converter built by getTypeConverter(). LLVM-lowering
+/// pattern descriptors (e.g. vector_to_llvm) compare against this exact name.
+StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
+ return StringRef("LLVMTypeConverter");
+}
+
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index e6baf470199837..70a8c1c7eeab84 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -589,14 +589,24 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
if (!llvm::hasSingleElement(typeConverterRegion.front()))
return emitOpError()
<< "expected exactly one op in default type converter region";
- Operation *typeConverterOp = &typeConverterRegion.front().front();
- if (!isa<transform::TypeConverterBuilderOpInterface>(typeConverterOp)) {
+ Operation *maybeTypeConverter = &typeConverterRegion.front().front();
+ auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
+ maybeTypeConverter);
+ if (!typeConverterOp) {
InFlightDiagnostic diag = emitOpError()
<< "expected default converter child op to "
"implement TypeConverterBuilderOpInterface";
- diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
+ // `typeConverterOp` is null on this path; attach the note to the raw op.
+ diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
return diag;
}
+ // Check default type converter type.
+ if (!getPatterns().empty()) {
+ for (Operation &op : getPatterns().front()) {
+ auto descriptor =
+ cast<transform::ConversionPatternDescriptorOpInterface>(&op);
+ if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
+ return failure();
+ }
+ }
}
if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() &&
!getIllegalDialects())
diff --git a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
index 9cdbdfcfaf6c82..b9cedf5efae659 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
@@ -9,7 +9,10 @@ add_mlir_dialect_library(MLIRVectorTransformOps
LINK_LIBS PUBLIC
MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
MLIRVectorDialect
+ MLIRVectorToLLVM
MLIRVectorTransforms
MLIRSideEffectInterfaces
MLIRTransformDialect
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 8572b6df75bace..94f19e59669eaf 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -7,6 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -23,6 +26,25 @@ using namespace mlir;
using namespace mlir::vector;
using namespace mlir::transform;
+//===----------------------------------------------------------------------===//
+// Apply...ConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+/// Adds the vector-to-LLVM conversion patterns, configured by this op's
+/// attributes, to `patterns`.
+void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns) {
+ // The unchecked downcast relies on verifyTypeConverter having accepted the
+ // converter as an "LLVMTypeConverter". NOTE(review): assumes `typeConverter`
+ // is the verified default converter -- confirm.
+ auto &llvmConverter = static_cast<LLVMTypeConverter &>(typeConverter);
+ bool reassociateFpReductions = getReassociateFpReductions();
+ bool force32BitIndices = getForce_32bitVectorIndices();
+ populateVectorToLLVMConversionPatterns(llvmConverter, patterns,
+ reassociateFpReductions,
+ force32BitIndices);
+}
+
+/// Accepts only type converter builders that report "LLVMTypeConverter",
+/// because populatePatterns downcasts the converter to that class.
+LogicalResult
+transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
+ transform::TypeConverterBuilderOpInterface builder) {
+ StringRef converterKind = builder.getTypeConverterType();
+ if (converterKind == "LLVMTypeConverter")
+ return success();
+ return emitOpError("expected LLVMTypeConverter");
+}
+
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
new file mode 100644
index 00000000000000..73eb45470ab798
--- /dev/null
+++ b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
+
+// Checks that applying the vector_to_llvm conversion patterns, with the
+// LLVMTypeConverter provided by memref_to_llvm_type_converter, rewrites
+// vector.bitcast into llvm.bitcast.
+// CHECK-LABEL: func @lower_to_llvm
+// CHECK-NOT: vector.bitcast
+// CHECK: llvm.bitcast
+func.func @lower_to_llvm(%input: vector<f32>) -> vector<i32> {
+ %0 = vector.bitcast %input : vector<f32> to vector<i32>
+ return %0 : vector<i32>
+}
+
+// Transform script: match every func.func and apply the conversion with only
+// the func and llvm dialects marked legal.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_conversion_patterns to %0 {
+ transform.apply_conversion_patterns.vector.vector_to_llvm
+ } with type_converter {
+ transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+ } {legal_dialects = ["func", "llvm"]} : !transform.any_op
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index fa5285e5645802..adb4a59f370da0 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4146,12 +4146,14 @@ cc_library(
":ArithDialect",
":AsmParser",
":IR",
+ ":LLVMCommonConversion",
":LLVMDialect",
":SideEffectInterfaces",
":TransformDialect",
":TransformUtils",
":VectorDialect",
":VectorEnumsIncGen",
+ ":VectorToLLVM",
":VectorToSCF",
":VectorTransformOpsIncGen",
":VectorTransforms",
@@ -11510,6 +11512,7 @@ cc_library(
":AffineDialect",
":ArithDialect",
":IR",
+ ":LLVMCommonConversion",
":LoopLikeInterface",
":MemRefDialect",
":MemRefTransformOpsIncGen",
More information about the Mlir-commits
mailing list