[Mlir-commits] [mlir] 9afbca7 - [mlir] ConvertStandardToLLVM: make AllocLikeOpLowering public

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat May 22 02:58:29 PDT 2021


Author: Butygin
Date: 2021-05-22T12:57:45+03:00
New Revision: 9afbca746b6c93e5359e5723e5f39c21bca2f4ac

URL: https://github.com/llvm/llvm-project/commit/9afbca746b6c93e5359e5723e5f39c21bca2f4ac
DIFF: https://github.com/llvm/llvm-project/commit/9afbca746b6c93e5359e5723e5f39c21bca2f4ac.diff

LOG: [mlir] ConvertStandardToLLVM: make AllocLikeOpLowering public

This is useful for anyone who wants to implement a custom AllocOp LLVM lowering.
The class moves to the public header and is renamed to AllocLikeOpLLVMLowering.

Differential Revision: https://reviews.llvm.org/D102932

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 43a1afcc3f892..07b2df58ea1bb 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -606,6 +606,59 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   using ConvertToLLVMPattern::matchAndRewrite;
 };
 
+/// Lowering for AllocOp and AllocaOp.
+struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
+  using ConvertToLLVMPattern::createIndexConstant;
+  using ConvertToLLVMPattern::getIndexType;
+  using ConvertToLLVMPattern::getVoidPtrType;
+
+  explicit AllocLikeOpLLVMLowering(StringRef opName,
+                                   LLVMTypeConverter &converter)
+      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
+
+protected:
+  // Returns 'input' aligned up to 'alignment'. Computes
+  // bumped = input + alignment - 1
+  // aligned = bumped - bumped % alignment
+  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+                             Value input, Value alignment);
+
+  /// Allocates the underlying buffer. Returns the allocated pointer and the
+  /// aligned pointer.
+  virtual std::tuple<Value, Value>
+  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
+                 Value sizeBytes, Operation *op) const = 0;
+
+private:
+  static MemRefType getMemRefResultType(Operation *op) {
+    return op->getResult(0).getType().cast<MemRefType>();
+  }
+
+  LogicalResult match(Operation *op) const override {
+    MemRefType memRefType = getMemRefResultType(op);
+    return success(isConvertibleAndHasIdentityMaps(memRefType));
+  }
+
+  // An `alloc` is converted into a definition of a memref descriptor value and
+  // a call to `malloc` to allocate the underlying data buffer.  The memref
+  // descriptor is of the LLVM structure type where:
+  //   1. the first element is a pointer to the allocated (typed) data buffer,
+  //   2. the second element is a pointer to the (typed) payload, aligned to the
+  //      specified alignment,
+  //   3. the remaining elements serve to store all the sizes and strides of the
+  //      memref using LLVM-converted `index` type.
+  //
+  // Alignment is performed by allocating `alignment` more bytes than
+  // requested and shifting the aligned pointer relative to the allocated
+  // memory. Note: `alignment - <minimum malloc alignment>` would actually be
+  // sufficient. If alignment is unspecified, the two pointers are equal.
+
+  // An `alloca` is converted into a definition of a memref descriptor value and
+  // an llvm.alloca to allocate the underlying data buffer.
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const override;
+};
+
 namespace LLVM {
 namespace detail {
 /// Replaces the given operation "op" with a new operation of type "targetOp"

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index dd9a061e83d13..7e289a5f5d995 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1831,92 +1831,10 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
   }
 };
 
-/// Lowering for AllocOp and AllocaOp.
-struct AllocLikeOpLowering : public ConvertToLLVMPattern {
-  using ConvertToLLVMPattern::createIndexConstant;
-  using ConvertToLLVMPattern::getIndexType;
-  using ConvertToLLVMPattern::getVoidPtrType;
-
-  explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
-      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
-
-protected:
-  // Returns 'input' aligned up to 'alignment'. Computes
-  // bumped = input + alignement - 1
-  // aligned = bumped - bumped % alignment
-  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
-                             Value input, Value alignment) {
-    Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
-    Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
-    Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
-    Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
-    return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
-  }
-
-  /// Allocates the underlying buffer. Returns the allocated pointer and the
-  /// aligned pointer.
-  virtual std::tuple<Value, Value>
-  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
-                 Value sizeBytes, Operation *op) const = 0;
-
-private:
-  static MemRefType getMemRefResultType(Operation *op) {
-    return op->getResult(0).getType().cast<MemRefType>();
-  }
-
-  LogicalResult match(Operation *op) const override {
-    MemRefType memRefType = getMemRefResultType(op);
-    return success(isConvertibleAndHasIdentityMaps(memRefType));
-  }
-
-  // An `alloc` is converted into a definition of a memref descriptor value and
-  // a call to `malloc` to allocate the underlying data buffer.  The memref
-  // descriptor is of the LLVM structure type where:
-  //   1. the first element is a pointer to the allocated (typed) data buffer,
-  //   2. the second element is a pointer to the (typed) payload, aligned to the
-  //      specified alignment,
-  //   3. the remaining elements serve to store all the sizes and strides of the
-  //      memref using LLVM-converted `index` type.
-  //
-  // Alignment is performed by allocating `alignment` more bytes than
-  // requested and shifting the aligned pointer relative to the allocated
-  // memory. Note: `alignment - <minimum malloc alignment>` would actually be
-  // sufficient. If alignment is unspecified, the two pointers are equal.
-
-  // An `alloca` is converted into a definition of a memref descriptor value and
-  // an llvm.alloca to allocate the underlying data buffer.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const override {
-    MemRefType memRefType = getMemRefResultType(op);
-    auto loc = op->getLoc();
-
-    // Get actual sizes of the memref as values: static sizes are constant
-    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
-    // zero-dimensional memref, assume a scalar (size 1).
-    SmallVector<Value, 4> sizes;
-    SmallVector<Value, 4> strides;
-    Value sizeBytes;
-    this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
-                                   strides, sizeBytes);
-
-    // Allocate the underlying buffer.
-    Value allocatedPtr;
-    Value alignedPtr;
-    std::tie(allocatedPtr, alignedPtr) =
-        this->allocateBuffer(rewriter, loc, sizeBytes, op);
-
-    // Create the MemRef descriptor.
-    auto memRefDescriptor = this->createMemRefDescriptor(
-        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
-
-    // Return the final value of the descriptor.
-    rewriter.replaceOp(op, {memRefDescriptor});
-  }
-};
-
-struct AllocOpLowering : public AllocLikeOpLowering {
+struct AllocOpLowering : public AllocLikeOpLLVMLowering {
   AllocOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
+      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
+                                converter) {}
 
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                           Location loc, Value sizeBytes,
@@ -1967,9 +1885,10 @@ struct AllocOpLowering : public AllocLikeOpLowering {
   }
 };
 
-struct AlignedAllocOpLowering : public AllocLikeOpLowering {
+struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
   AlignedAllocOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
+      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
+                                converter) {}
 
   /// Returns the memref's element size in bytes.
   // TODO: there are other places where this is used. Expose publicly?
@@ -2047,9 +1966,10 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
 // Out of line definition, required till C++17.
 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
 
-struct AllocaOpLowering : public AllocLikeOpLowering {
+struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
   AllocaOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::AllocaOp::getOperationName(), converter) {}
+      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
+                                converter) {}
 
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
@@ -2310,10 +2230,10 @@ struct GlobalMemrefOpLowering
 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
 /// the first element stashed into the descriptor. This reuses
 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
-struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
+struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::GetGlobalOp::getOperationName(),
-                            converter) {}
+      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
+                                converter) {}
 
   /// Buffer "allocation" for memref.get_global op is getting the address of
   /// the global variable referenced.
@@ -4195,6 +4115,45 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
 };
 } // end namespace
 
+Value AllocLikeOpLLVMLowering::createAligned(
+    ConversionPatternRewriter &rewriter, Location loc, Value input,
+    Value alignment) {
+  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
+  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+}
+
+void AllocLikeOpLLVMLowering::rewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  MemRefType memRefType = getMemRefResultType(op);
+  auto loc = op->getLoc();
+
+  // Get actual sizes of the memref as values: static sizes are constant
+  // values and dynamic sizes are passed to 'alloc' as operands.  In case of
+  // zero-dimensional memref, assume a scalar (size 1).
+  SmallVector<Value, 4> sizes;
+  SmallVector<Value, 4> strides;
+  Value sizeBytes;
+  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
+                                 strides, sizeBytes);
+
+  // Allocate the underlying buffer.
+  Value allocatedPtr;
+  Value alignedPtr;
+  std::tie(allocatedPtr, alignedPtr) =
+      this->allocateBuffer(rewriter, loc, sizeBytes, op);
+
+  // Create the MemRef descriptor.
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
+
+  // Return the final value of the descriptor.
+  rewriter.replaceOp(op, {memRefDescriptor});
+}
+
 mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
     : ConversionTarget(ctx) {
   this->addLegalDialect<LLVM::LLVMDialect>();


        


More information about the Mlir-commits mailing list