[Mlir-commits] [mlir] 987fbae - [mlir] StandardToLLVM: make one-to-one conversion pattern publicly available

Alex Zinenko llvmlistbot at llvm.org
Thu Mar 26 10:24:37 PDT 2020


Author: Alex Zinenko
Date: 2020-03-26T18:24:07+01:00
New Revision: 987fbae0add347b72a1d705ca6349e39549a25d5

URL: https://github.com/llvm/llvm-project/commit/987fbae0add347b72a1d705ca6349e39549a25d5
DIFF: https://github.com/llvm/llvm-project/commit/987fbae0add347b72a1d705ca6349e39549a25d5.diff

LOG: [mlir] StandardToLLVM: make one-to-one conversion pattern publicly available

Summary:
The Standard-to-LLVM dialect conversion has a set of utility classes that
simplify conversions, including patterns that provide one-to-one operation
conversion with optional result packing. Expose these classes in a public
header so that conversions other than Standard-to-LLVM (e.g. vectors, or
LLVM-based intrinsics) can also use them. Since the patterns are implemented
as class templates, keep the code size limited by moving the implementation
into a private, non-template function that creates the target operation from
its name (op identifier) instead of using template-based builders.

Differential Revision: https://reviews.llvm.org/D76864

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 2e6a513f10db..95da9805606b 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -90,6 +90,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// to each of the MLIR types converted with `convertType`.
   Type packFunctionResults(ArrayRef<Type> types);
 
+  /// Returns the MLIR context.
+  MLIRContext &getContext();
+
   /// Returns the LLVM context.
   llvm::LLVMContext &getLLVMContext();
 
@@ -356,6 +359,7 @@ class UnrankedMemRefDescriptor : public StructBuilder {
   /// `unpack`.
   static unsigned getNumUnpackedValues() { return 2; }
 };
+
 /// Base class for operation conversions targeting the LLVM IR dialect. Provides
 /// conversion patterns with access to an LLVMTypeConverter.
 class ConvertToLLVMPattern : public ConversionPattern {
@@ -364,11 +368,79 @@ class ConvertToLLVMPattern : public ConversionPattern {
                        LLVMTypeConverter &typeConverter,
                        PatternBenefit benefit = 1);
 
+  /// Returns the LLVM dialect.
+  LLVM::LLVMDialect &getDialect() const;
+
+  /// Returns the LLVM IR context.
+  llvm::LLVMContext &getContext() const;
+
+  /// Returns the LLVM IR module associated with the LLVM dialect.
+  llvm::Module &getModule() const;
+
+  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
+  /// defined by the pointer size used in the LLVM module.
+  LLVM::LLVMType getIndexType() const;
+
+  /// Gets the MLIR type wrapping the LLVM void type.
+  LLVM::LLVMType getVoidType() const;
+
+  /// Get the MLIR type wrapping the LLVM i8* type.
+  LLVM::LLVMType getVoidPtrType() const;
+
+  /// Create an LLVM dialect operation defining the given index constant.
+  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
+                            uint64_t value) const;
+
 protected:
   /// Reference to the type converter, with potential extensions.
   LLVMTypeConverter &typeConverter;
 };
 
+/// Utility class for operation conversions targeting the LLVM dialect that
+/// match exactly one source operation.
+template <typename OpTy>
+class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
+public:
+  ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                         PatternBenefit benefit = 1)
+      : ConvertToLLVMPattern(OpTy::getOperationName(),
+                             &typeConverter.getContext(), typeConverter,
+                             benefit) {}
+};
+
+namespace LLVM {
+namespace detail {
+/// Replaces the given operation "op" with a new operation named "targetOp"
+/// and the given operands, converting result types with `typeConverter`.
+LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
+                              ValueRange operands,
+                              LLVMTypeConverter &typeConverter,
+                              ConversionPatternRewriter &rewriter);
+} // namespace detail
+} // namespace LLVM
+
+/// Generic implementation of one-to-one conversion from "SourceOp" to
+/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
+/// Upholds a convention that multi-result operations get converted into an
+/// operation returning the LLVM IR structure type, in which case individual
+/// values must be extracted using LLVM::ExtractValueOp before being used.
+template <typename SourceOp, typename TargetOp>
+class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+public:
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+  using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
+
+  /// Converts the result types to LLVM types, passes operands as is and
+  /// preserves attributes (delegates to LLVM::detail::oneToOneRewrite).
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
+                                         operands, this->typeConverter,
+                                         rewriter);
+  }
+};
+
 /// Derived class that automatically populates legalization information for
 /// different LLVM ops.
 class LLVMConversionTarget : public ConversionTarget {

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index d37a7733a713..474a4f08b9f6 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -148,6 +148,11 @@ LLVMTypeConverter::LLVMTypeConverter(
   addConversion([](LLVM::LLVMType type) { return type; });
 }
 
+/// Returns the MLIR context.
+MLIRContext &LLVMTypeConverter::getContext() {
+  return *getDialect()->getContext();
+}
+
 /// Get the LLVM context.
 llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
   return module->getContext();
@@ -699,52 +704,35 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
   results.push_back(d.memRefDescPtr(builder, loc));
 }
 
-namespace {
-// Base class for Standard to LLVM IR op conversions.  Matches the Op type
-// provided as template argument.  Carries a reference to the LLVM dialect in
-// case it is necessary for rewriters.
-template <typename SourceOp>
-class LLVMLegalizationPattern : public ConvertToLLVMPattern {
-public:
-  // Construct a conversion pattern.
-  explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
-                                   LLVMTypeConverter &typeConverter_)
-      : ConvertToLLVMPattern(SourceOp::getOperationName(),
-                             dialect_.getContext(), typeConverter_),
-        dialect(dialect_) {}
-
-  // Get the LLVM IR dialect.
-  LLVM::LLVMDialect &getDialect() const { return dialect; }
-  // Get the LLVM context.
-  llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
-  // Get the LLVM module in which the types are constructed.
-  llvm::Module &getModule() const { return dialect.getLLVMModule(); }
-
-  // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
-  // by the pointer size used in the LLVM module.
-  LLVM::LLVMType getIndexType() const {
-    return LLVM::LLVMType::getIntNTy(
-        &dialect, getModule().getDataLayout().getPointerSizeInBits());
-  }
+LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
+  return *typeConverter.getDialect();
+}
 
-  LLVM::LLVMType getVoidType() const {
-    return LLVM::LLVMType::getVoidTy(&dialect);
-  }
+llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
+  return typeConverter.getLLVMContext();
+}
 
-  // Get the MLIR type wrapping the LLVM i8* type.
-  LLVM::LLVMType getVoidPtrType() const {
-    return LLVM::LLVMType::getInt8PtrTy(&dialect);
-  }
+llvm::Module &ConvertToLLVMPattern::getModule() const {
+  return getDialect().getLLVMModule();
+}
 
-  // Create an LLVM IR pseudo-operation defining the given index constant.
-  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
-                            uint64_t value) const {
-    return createIndexAttrConstant(builder, loc, getIndexType(), value);
-  }
+LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
+  return LLVM::LLVMType::getIntNTy(
+      &getDialect(), getModule().getDataLayout().getPointerSizeInBits());
+}
 
-protected:
-  LLVM::LLVMDialect &dialect;
-};
+LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
+  return LLVM::LLVMType::getVoidTy(&getDialect());
+}
+
+LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
+  return LLVM::LLVMType::getInt8PtrTy(&getDialect());
+}
+
+Value ConvertToLLVMPattern::createIndexConstant(
+    ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
+  return createIndexAttrConstant(builder, loc, getIndexType(), value);
+}
 
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
@@ -876,9 +864,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   builder.create<LLVM::ReturnOp>(loc, call.getResults());
 }
 
-struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
+namespace {
+
+struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
 protected:
-  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
+  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
   using UnsignedTypePair = std::pair<unsigned, Type>;
 
   // Gather the positions and types of memref-typed arguments in a given
@@ -942,9 +932,8 @@ struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
 /// information.
 struct FuncOpConversion : public FuncOpConversionBase {
-  FuncOpConversion(LLVM::LLVMDialect &dialect, LLVMTypeConverter &converter,
-                   bool emitCWrappers)
-      : FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
+  FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers)
+      : FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1022,7 +1011,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
 };
 
 //////////////// Support for Lowering operations on n-D vectors ////////////////
-namespace {
 // Helper struct to "unroll" operations on n-D vectors in terms of operations on
 // 1-D LLVM vectors.
 struct NDVectorTypeInfo {
@@ -1098,55 +1086,49 @@ void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
     fun(position);
   }
 }
-////////////// End Support for Lowering operations on n-D vectors //////////////
-
-// Basic lowering implementation for one-to-one rewriting from Standard Ops to
-// LLVM Dialect Ops.
-template <typename SourceOp, typename TargetOp>
-struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
-  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
-  using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;
-
-  // Convert the type of the result to an LLVM type, pass operands as is,
-  // preserve attributes.
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    unsigned numResults = op->getNumResults();
-
-    Type packedType;
-    if (numResults != 0) {
-      packedType =
-          this->typeConverter.packFunctionResults(op->getResultTypes());
-      if (!packedType)
-        return failure();
-    }
 
-    auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
-                                           op->getAttrs());
+/// Replaces the given operation "op" with a new operation named "targetOp"
+/// and the given operands, converting result types with `typeConverter`.
+LogicalResult LLVM::detail::oneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+  unsigned numResults = op->getNumResults();
 
-    // If the operation produced 0 or 1 result, return them immediately.
-    if (numResults == 0)
-      return rewriter.eraseOp(op), success();
-    if (numResults == 1)
-      return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
-             success();
+  Type packedType;
+  if (numResults != 0) {
+    packedType = typeConverter.packFunctionResults(op->getResultTypes());
+    if (!packedType)
+      return failure();
+  }
 
-    // Otherwise, it had been converted to an operation producing a structure.
-    // Extract individual results from the structure and return them as list.
-    SmallVector<Value, 4> results;
-    results.reserve(numResults);
-    for (unsigned i = 0; i < numResults; ++i) {
-      auto type = this->typeConverter.convertType(op->getResult(i).getType());
-      results.push_back(rewriter.create<LLVM::ExtractValueOp>(
-          op->getLoc(), type, newOp.getOperation()->getResult(0),
-          rewriter.getI64ArrayAttr(i)));
-    }
-    rewriter.replaceOp(op, results);
-    return success();
+  // Create the operation through state since we don't know its C++ type.
+  OperationState state(op->getLoc(), targetOp);
+  state.addTypes(packedType);
+  state.addOperands(operands);
+  state.addAttributes(op->getAttrs());
+  Operation *newOp = rewriter.createOperation(state);
+  // NOTE(review): for 0 results the null packedType is still added; confirm.
+  // If the operation produced 0 or 1 result, return them immediately.
+  if (numResults == 0)
+    return rewriter.eraseOp(op), success();
+  if (numResults == 1)
+    return rewriter.replaceOp(op, newOp->getResult(0)), success();
+
+  // Otherwise, it had been converted to an operation producing a structure.
+  // Extract individual results from the structure and return them as list.
+  SmallVector<Value, 4> results;
+  results.reserve(numResults);
+  for (unsigned i = 0; i < numResults; ++i) {
+    auto type = typeConverter.convertType(op->getResult(i).getType());
+    results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+        op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
   }
-};
+  rewriter.replaceOp(op, results);
+  return success();
+}
 
+////////////// End Support for Lowering operations on n-D vectors //////////////
+namespace {
 template <typename SourceOp, unsigned OpCount>
 struct OpCountValidator {
   static_assert(
@@ -1201,8 +1183,8 @@ static LogicalResult HandleMultidimensionalVectors(
 // Ops for N-ary ops with one result. This supports higher-dimensional vector
 // types.
 template <typename SourceOp, typename TargetOp, unsigned OpCount>
-struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
-  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+struct NaryOpLLVMOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
 
   // Convert the type of the result to an LLVM type, pass operands as is,
@@ -1333,23 +1315,23 @@ struct CopySignOpLowering
   using Super::Super;
 };
 struct SelectOpLowering
-    : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
+    : public OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp> {
   using Super::Super;
 };
 struct ConstLLVMOpLowering
-    : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
+    : public OneToOneConvertToLLVMPattern<ConstantOp, LLVM::ConstantOp> {
   using Super::Super;
 };
 struct ShiftLeftOpLowering
-    : public OneToOneLLVMOpLowering<ShiftLeftOp, LLVM::ShlOp> {
+    : public OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp> {
   using Super::Super;
 };
 struct SignedShiftRightOpLowering
-    : public OneToOneLLVMOpLowering<SignedShiftRightOp, LLVM::AShrOp> {
+    : public OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp> {
   using Super::Super;
 };
 struct UnsignedShiftRightOpLowering
-    : public OneToOneLLVMOpLowering<UnsignedShiftRightOp, LLVM::LShrOp> {
+    : public OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp> {
   using Super::Super;
 };
 
@@ -1373,13 +1355,11 @@ static bool isSupportedMemRefType(MemRefType type) {
 // Alignment is obtained by allocating `alignment - 1` more bytes than requested
 // and shifting the aligned pointer relative to the allocated memory. If
 // alignment is unspecified, the two pointers are equal.
-struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
-  using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
+struct AllocOpLowering : public ConvertOpToLLVMPattern<AllocOp> {
+  using ConvertOpToLLVMPattern<AllocOp>::ConvertOpToLLVMPattern;
 
-  AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
-                  bool useAlloca = false)
-      : LLVMLegalizationPattern<AllocOp>(dialect_, converter),
-        useAlloca(useAlloca) {}
+  explicit AllocOpLowering(LLVMTypeConverter &converter, bool useAlloca = false)
+      : ConvertOpToLLVMPattern<AllocOp>(converter), useAlloca(useAlloca) {}
 
   LogicalResult match(Operation *op) const override {
     MemRefType type = cast<AllocOp>(op).getType();
@@ -1569,10 +1549,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
 // passes the pointer to the MemRef across function boundaries.
 template <typename CallOpType>
-struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
-  using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern;
+struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
+  using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
   using Super = CallOpInterfaceLowering<CallOpType>;
-  using Base = LLVMLegalizationPattern<CallOpType>;
+  using Base = ConvertOpToLLVMPattern<CallOpType>;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1639,13 +1619,12 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
 // The memref descriptor being an SSA value, there is no need to clean it up
 // in any way.
-struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
-  using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
+struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
+  using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
 
-  DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
-                    bool useAlloca = false)
-      : LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
-        useAlloca(useAlloca) {}
+  explicit DeallocOpLowering(LLVMTypeConverter &converter,
+                             bool useAlloca = false)
+      : ConvertOpToLLVMPattern<DeallocOp>(converter), useAlloca(useAlloca) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1680,8 +1659,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
 };
 
 // A `rsqrt` is converted into `1 / sqrt`.
-struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
-  using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
+struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
+  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1737,8 +1716,8 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
   }
 };
 
-struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
-  using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
+struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
+  using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult match(Operation *op) const override {
     auto memRefCastOp = cast<MemRefCastOp>(op);
@@ -1833,8 +1812,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
 };
 
 struct DialectCastOpLowering
-    : public LLVMLegalizationPattern<LLVM::DialectCastOp> {
-  using LLVMLegalizationPattern<LLVM::DialectCastOp>::LLVMLegalizationPattern;
+    : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
+  using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1852,8 +1831,8 @@ struct DialectCastOpLowering
 
 // A `dim` is converted to a constant for static sizes and to an access to the
 // size stored in the memref descriptor for dynamic sizes.
-struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
-  using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
+struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
+  using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1880,8 +1859,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
 // to supported MemRef types.  Provides functionality to emit code accessing a
 // specific element of the underlying data buffer.
 template <typename Derived>
-struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
-  using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
+struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
+  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
   using Base = LoadStoreOpLowering<Derived>;
 
   LogicalResult match(Operation *op) const override {
@@ -2029,8 +2008,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
 // an integer.  If the bit width of the source and target integer types is the
 // same, just erase the cast.  If the target type is wider, sign-extend the
 // value, otherwise truncate it.
-struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
-  using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
+struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
+  using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2064,8 +2043,8 @@ static LLVMPredType convertCmpPredicate(StdPredType pred) {
   return static_cast<LLVMPredType>(pred);
 }
 
-struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
-  using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
+struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
+  using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2083,8 +2062,8 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
   }
 };
 
-struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
-  using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
+struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
+  using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2103,39 +2082,40 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
 };
 
 struct SIToFPLowering
-    : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
+    : public OneToOneConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp> {
   using Super::Super;
 };
 
-struct FPExtLowering : public OneToOneLLVMOpLowering<FPExtOp, LLVM::FPExtOp> {
+struct FPExtLowering
+    : public OneToOneConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp> {
   using Super::Super;
 };
 
 struct FPTruncLowering
-    : public OneToOneLLVMOpLowering<FPTruncOp, LLVM::FPTruncOp> {
+    : public OneToOneConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp> {
   using Super::Super;
 };
 
 struct SignExtendIOpLowering
-    : public OneToOneLLVMOpLowering<SignExtendIOp, LLVM::SExtOp> {
+    : public OneToOneConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp> {
   using Super::Super;
 };
 
 struct TruncateIOpLowering
-    : public OneToOneLLVMOpLowering<TruncateIOp, LLVM::TruncOp> {
+    : public OneToOneConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp> {
   using Super::Super;
 };
 
 struct ZeroExtendIOpLowering
-    : public OneToOneLLVMOpLowering<ZeroExtendIOp, LLVM::ZExtOp> {
+    : public OneToOneConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp> {
   using Super::Super;
 };
 
 // Base class for LLVM IR lowering terminator operations with successors.
 template <typename SourceOp, typename TargetOp>
 struct OneToOneLLVMTerminatorLowering
-    : public LLVMLegalizationPattern<SourceOp> {
-  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+    : public ConvertOpToLLVMPattern<SourceOp> {
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
 
   LogicalResult
@@ -2153,8 +2133,8 @@ struct OneToOneLLVMTerminatorLowering
 // can only return 0 or 1 value, we pack multiple values into a structure type.
 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
 // necessary before returning it
-struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
-  using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
+struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
+  using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2202,8 +2182,8 @@ struct CondBranchOpLowering
 
 // The Splat operation is lowered to an insertelement + a shufflevector
 // operation. Splat to only 1-d vector result types are lowered.
-struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
-  using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
+struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
+  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2236,8 +2216,8 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
 // The Splat operation is lowered to an insertelement + a shufflevector
 // operation. Splat to only 2+-d vector result types are lowered by the
 // SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
-  using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
+struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
+  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2290,8 +2270,8 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
 ///      and stride.
 /// The subview op is replaced by the descriptor.
-struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
-  using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
+struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
+  using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2418,8 +2398,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
 ///      and stride.
 /// The view op is replaced by the descriptor.
-struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
-  using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;
+struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
+  using ConvertOpToLLVMPattern<ViewOp>::ConvertOpToLLVMPattern;
 
   // Build and return the value for the idx^th shape dimension, either by
   // returning the constant shape dimension or counting the proper dynamic size.
@@ -2535,8 +2515,8 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
 };
 
 struct AssumeAlignmentOpLowering
-    : public LLVMLegalizationPattern<AssumeAlignmentOp> {
-  using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
+    : public ConvertOpToLLVMPattern<AssumeAlignmentOp> {
+  using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -2788,7 +2768,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       UnsignedRemIOpLowering,
       UnsignedShiftRightOpLowering,
       XOrOpLowering,
-      ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+      ZeroExtendIOpLowering>(converter);
   // clang-format on
 }
 
@@ -2803,19 +2783,17 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
       MemRefCastOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
-      ViewOpLowering>(*converter.getDialect(), converter);
+      ViewOpLowering>(converter);
   patterns.insert<
       AllocOpLowering,
-      DeallocOpLowering>(
-        *converter.getDialect(), converter, useAlloca);
+      DeallocOpLowering>(converter, useAlloca);
   // clang-format on
 }
 
 void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     bool emitCWrappers) {
-  patterns.insert<FuncOpConversion>(*converter.getDialect(), converter,
-                                    emitCWrappers);
+  patterns.insert<FuncOpConversion>(converter, emitCWrappers);
 }
 
 void mlir::populateStdToLLVMConversionPatterns(
@@ -2829,7 +2807,7 @@ void mlir::populateStdToLLVMConversionPatterns(
 
 static void populateStdToLLVMBarePtrFuncOpConversionPattern(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<BarePtrFuncOpConversion>(*converter.getDialect(), converter);
+  patterns.insert<BarePtrFuncOpConversion>(converter);
 }
 
 void mlir::populateStdToLLVMBarePtrConversionPatterns(


        


More information about the Mlir-commits mailing list