[Mlir-commits] [mlir] 7ec88f0 - [mlir][memref][transform] Add vector_to_llvm conversion patterns

Matthias Springer llvmlistbot at llvm.org
Wed Aug 9 02:28:05 PDT 2023


Author: Matthias Springer
Date: 2023-08-09T11:27:53+02:00
New Revision: 7ec88f06d5833dfb4c7029c7645ae6cb89520504

URL: https://github.com/llvm/llvm-project/commit/7ec88f06d5833dfb4c7029c7645ae6cb89520504
DIFF: https://github.com/llvm/llvm-project/commit/7ec88f06d5833dfb4c7029c7645ae6cb89520504.diff

LOG: [mlir][memref][transform] Add vector_to_llvm conversion patterns

These patterns are exposed via a new "apply_conversion_patterns.vector.vector_to_llvm" pattern descriptor op, which is used inside an "apply_conversion_patterns" op.

Also provide a new type converter op that converts memref types to LLVM types. Conversion patterns that lower to LLVM are special: they require an `LLVMTypeConverter`; a plain `TypeConverter` is not enough. This revision therefore also adds a new interface method to conversion pattern descriptor ops that verifies that the default type converter of the enclosing "apply_conversion_patterns" op is compatible with the set of patterns. At the moment, the kind of type converter is identified by a simple `StringRef`. This can evolve into a richer type in the future if needed.

Differential Revision: https://reviews.llvm.org/D157369

Added: 
    mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 86a3586bcc58ba..243ce16b019dee 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -15,6 +15,33 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.memref.memref_to_llvm_type_converter",
+    [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                               ["getTypeConverterType"]>]> {
+  let description = [{
+    This operation provides an "LLVMTypeConverter" that lowers memref types to
+    LLVM types.
+
+    The type converter can be customized as follows:
+    - `use_aligned_alloc`: Use aligned_alloc in place of malloc for heap
+      allocations.
+    - `index_bitwidth`: Bitwidth of the index type; "0" (the default) indicates
+      the size of a machine word.
+    - `use_generic_functions`: Use generic allocation and deallocation functions
+      instead of the classic "malloc", "aligned_alloc" and "free" functions.
+    - `use_opaque_pointers`: Generate LLVM IR using opaque pointers instead of
+      typed pointers.
+  }];
+
+  let arguments = (ins
+      DefaultValuedAttr<BoolAttr, "false">:$use_aligned_alloc,
+      DefaultValuedAttr<I64Attr, "0">:$index_bitwidth,
+      DefaultValuedAttr<BoolAttr, "false">:$use_generic_functions,
+      DefaultValuedAttr<BoolAttr, "false">:$use_opaque_pointers);
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyExpandOpsPatternsOp : Op<Transform_Dialect,
     "apply_patterns.memref.expand_ops",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index 5ce21be223bcb2..d40d780e73c554 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -263,6 +263,39 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
   ];
 }
 
+def TypeConverterBuilderOpInterface
+    : OpInterface<"TypeConverterBuilderOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that specify a type converter
+    for a dialect conversion. Such ops can be used with
+    "apply_conversion_patterns".
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type converter to be used with a dialect conversion.
+      }],
+      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
+      /*name=*/"getTypeConverter",
+      /*arguments=*/(ins)
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/[{
+        Return the kind of type converter that `getTypeConverter` returns.
+        This function is used for op verification.
+      }],
+      /*returnType=*/"StringRef",
+      /*name=*/"getTypeConverterType",
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return "TypeConverter"; }]
+    >,
+  ];
+}
+
 def ConversionPatternDescriptorOpInterface
     : OpInterface<"ConversionPatternDescriptorOpInterface"> {
   let description = [{
@@ -300,27 +333,16 @@ def ConversionPatternDescriptorOpInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/"return nullptr;"
     >,
-  ];
-}
-
-def TypeConverterBuilderOpInterface
-    : OpInterface<"TypeConverterBuilderOpInterface"> {
-  let description = [{
-    This interface should be implemented by ops that specify a type converter
-    for a dialect conversion. Such ops can be used with
-    "apply_conversion_patterns".
-  }];
-
-  let cppNamespace = "::mlir::transform";
-
-  let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Return the type converter to be used with a dialect conversion.
+        Verify the default type converter that is provided by the enclosing
+        "apply_conversion_patterns" op.
       }],
-      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
-      /*name=*/"getTypeConverter",
-      /*arguments=*/(ins)
+      /*returnType=*/"::mlir::LogicalResult",
+      /*name=*/"verifyTypeConverter",
+      /*arguments=*/(ins "TypeConverterBuilderOpInterface":$builder),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return success();"
     >,
   ];
 }

diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index e3d27cb4c71690..2b8c95a94257e6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -15,6 +15,27 @@ include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.vector.vector_to_llvm",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
+                               ["verifyTypeConverter"]>]> {
+  let description = [{
+    Collects patterns that convert vector dialect ops to LLVM dialect ops. These
+    patterns require an "LLVMTypeConverter".
+
+    The patterns can be customized as follows:
+    - `reassociate_fp_reductions` (default: false): Allows LLVM to reassociate
+      floating-point reductions for speed.
+    - `force_32bit_vector_indices` (default: true): Allows the compiler to
+      assume that vector indices fit in 32-bit if that yields faster code.
+  }];
+
+  let arguments = (ins
+      DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
+      DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyCastAwayVectorLeadingOneDimPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.cast_away_vector_leading_one_dim",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 671a4bd868ff81..559c0c00013d45 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -37,6 +37,8 @@ class Value;
 /// registered using addConversion and addMaterialization, respectively.
 class TypeConverter {
 public:
+  virtual ~TypeConverter() = default; // Subclasses may be owned via base ptrs.
+
   /// This class provides all of the information necessary to convert a type
   /// signature.
   class SignatureConversion {

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
index 0c9ee3d00afe69..e379663f3b0a01 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -11,6 +11,8 @@ add_mlir_dialect_library(MLIRMemRefTransformOps
   MLIRAffineDialect
   MLIRArithDialect
   MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
   MLIRLoopLikeInterface
   MLIRMemRefDialect
   MLIRMemRefTransforms

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index a46f2a37cb1d28..af067c0453333a 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -26,6 +28,29 @@ using namespace mlir;
 #define DEBUG_TYPE "memref-transforms"
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 
+//===----------------------------------------------------------------------===//
+// MemrefToLLVMTypeConverterOp
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<TypeConverter>
+transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
+  LowerToLLVMOptions options(getContext());
+  options.allocLowering =
+      (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
+                            : LowerToLLVMOptions::AllocLowering::Malloc);
+  options.useGenericFunctions = getUseGenericFunctions();
+  options.useOpaquePointers = getUseOpaquePointers();
+
+  if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
+    options.overrideIndexBitwidth(getIndexBitwidth());
+
+  return std::make_unique<LLVMTypeConverter>(getContext(), options);
+}
+
+StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
+  return "LLVMTypeConverter";
+}
+
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index e6baf470199837..70a8c1c7eeab84 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -589,14 +589,26 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
     if (!llvm::hasSingleElement(typeConverterRegion.front()))
       return emitOpError()
              << "expected exactly one op in default type converter region";
-    Operation *typeConverterOp = &typeConverterRegion.front().front();
-    if (!isa<transform::TypeConverterBuilderOpInterface>(typeConverterOp)) {
+    auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
+        &typeConverterRegion.front().front());
+    if (!typeConverterOp) {
       InFlightDiagnostic diag = emitOpError()
                                 << "expected default converter child op to "
                                    "implement TypeConverterBuilderOpInterface";
-      diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
+      // `typeConverterOp` is null here; point the note at the child op itself.
+      diag.attachNote(typeConverterRegion.front().front().getLoc())
+          << "op without interface";
       return diag;
     }
+    // Check default type converter type.
+    if (!getPatterns().empty()) {
+      for (Operation &op : getPatterns().front()) {
+        auto descriptor =
+            cast<transform::ConversionPatternDescriptorOpInterface>(&op);
+        if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
+          return failure();
+      }
+    }
   }
   if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() &&
       !getIllegalDialects())

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
index 9cdbdfcfaf6c82..b9cedf5efae659 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
@@ -9,7 +9,10 @@ add_mlir_dialect_library(MLIRVectorTransformOps
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
   MLIRVectorDialect
+  MLIRVectorToLLVM
   MLIRVectorTransforms
   MLIRSideEffectInterfaces
   MLIRTransformDialect

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 8572b6df75bace..94f19e59669eaf 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -7,6 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -23,6 +26,27 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::transform;
 
+//===----------------------------------------------------------------------===//
+// Apply...ConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  // Unchecked downcast: relies on verifyTypeConverter (below) rejecting type
+  // converter builders whose getTypeConverterType() is not "LLVMTypeConverter".
+  populateVectorToLLVMConversionPatterns(
+      static_cast<LLVMTypeConverter &>(typeConverter), patterns,
+      getReassociateFpReductions(), getForce_32bitVectorIndices());
+}
+
+LogicalResult
+transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
+    transform::TypeConverterBuilderOpInterface builder) {
+  if (builder.getTypeConverterType() != "LLVMTypeConverter")
+    return emitOpError("expected LLVMTypeConverter");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
new file mode 100644
index 00000000000000..73eb45470ab798
--- /dev/null
+++ b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @lower_to_llvm
+//   CHECK-NOT:   vector.bitcast
+//       CHECK:   llvm.bitcast
+func.func @lower_to_llvm(%input: vector<f32>) -> vector<i32> {
+  %0 = vector.bitcast %input : vector<f32> to vector<i32>
+  return %0 : vector<i32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.vector.vector_to_llvm
+  } with type_converter {
+    transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+  } {legal_dialects = ["func", "llvm"]} : !transform.any_op
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index fa5285e5645802..adb4a59f370da0 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4146,12 +4146,14 @@ cc_library(
         ":ArithDialect",
         ":AsmParser",
         ":IR",
+        ":LLVMCommonConversion",
         ":LLVMDialect",
         ":SideEffectInterfaces",
         ":TransformDialect",
         ":TransformUtils",
         ":VectorDialect",
         ":VectorEnumsIncGen",
+        ":VectorToLLVM",
         ":VectorToSCF",
         ":VectorTransformOpsIncGen",
         ":VectorTransforms",
@@ -11510,6 +11512,7 @@ cc_library(
         ":AffineDialect",
         ":ArithDialect",
         ":IR",
+        ":LLVMCommonConversion",
         ":LoopLikeInterface",
         ":MemRefDialect",
         ":MemRefTransformOpsIncGen",


        


More information about the Mlir-commits mailing list