[Mlir-commits] [mlir] 9925168 - [mlir] Convert `memref_reshape` to LLVM.

Alexander Belyaev llvmlistbot at llvm.org
Tue Nov 3 02:39:32 PST 2020


Author: Alexander Belyaev
Date: 2020-11-03T11:39:08+01:00
New Revision: 992516857691edecbbefc8e55c402908884301ba

URL: https://github.com/llvm/llvm-project/commit/992516857691edecbbefc8e55c402908884301ba
DIFF: https://github.com/llvm/llvm-project/commit/992516857691edecbbefc8e55c402908884301ba.diff

LOG: [mlir] Convert `memref_reshape` to LLVM.

https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15

Differential Revision: https://reviews.llvm.org/D90377

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/mlir-cpu-runner/memref_reshape.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 36734f809175..c52de63224b1 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -399,6 +399,65 @@ class UnrankedMemRefDescriptor : public StructBuilder {
                            LLVMTypeConverter &typeConverter,
                            ArrayRef<UnrankedMemRefDescriptor> values,
                            SmallVectorImpl<Value> &sizes);
+
+  /// TODO: The following accessors don't take alignment rules between elements
+  /// of the descriptor struct into account. For some architectures, it might be
+  /// necessary to extend them and to use `llvm::DataLayout` contained in
+  /// `LLVMTypeConverter`.
+
+  /// Builds IR extracting the allocated pointer from the descriptor:
+  /// `memRefDescPtr` is bitcast to `elemPtrPtrType` and loaded.
+  static Value allocatedPtr(OpBuilder &builder, Location loc,
+                            Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+  /// Builds IR inserting the allocated pointer into the descriptor:
+  /// `allocatedPtr` is stored through `memRefDescPtr` bitcast to
+  /// `elemPtrPtrType`.
+  static void setAllocatedPtr(OpBuilder &builder, Location loc,
+                              Value memRefDescPtr,
+                              LLVM::LLVMType elemPtrPtrType,
+                              Value allocatedPtr);
+
+  /// Builds IR extracting the aligned pointer from the descriptor. The pointer
+  /// is read from index 1 of `memRefDescPtr` viewed as `elemPtrPtrType`.
+  static Value alignedPtr(OpBuilder &builder, Location loc,
+                          LLVMTypeConverter &typeConverter, Value memRefDescPtr,
+                          LLVM::LLVMType elemPtrPtrType);
+  /// Builds IR inserting the aligned pointer into the descriptor. The pointer
+  /// is written to index 1 of `memRefDescPtr` viewed as `elemPtrPtrType`.
+  static void setAlignedPtr(OpBuilder &builder, Location loc,
+                            LLVMTypeConverter &typeConverter,
+                            Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType,
+                            Value alignedPtr);
+
+  /// Builds IR extracting the offset from the descriptor. The slot at index 2
+  /// of `memRefDescPtr` viewed as `elemPtrPtrType` is reinterpreted as a
+  /// pointer to the index type and loaded.
+  static Value offset(OpBuilder &builder, Location loc,
+                      LLVMTypeConverter &typeConverter, Value memRefDescPtr,
+                      LLVM::LLVMType elemPtrPtrType);
+  /// Builds IR inserting the offset into the descriptor, using the same
+  /// addressing as `offset`.
+  static void setOffset(OpBuilder &builder, Location loc,
+                        LLVMTypeConverter &typeConverter, Value memRefDescPtr,
+                        LLVM::LLVMType elemPtrPtrType, Value offset);
+
+  /// Builds IR extracting the pointer to the first element of the size array.
+  /// The result has type "pointer to index".
+  static Value sizeBasePtr(OpBuilder &builder, Location loc,
+                           LLVMTypeConverter &typeConverter,
+                           Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+  /// Builds IR extracting the size[index] from the descriptor. `sizeBasePtr`
+  /// is the value produced by `sizeBasePtr(...)`.
+  /// NOTE(review): `typeConverter` is taken by value here, unlike the
+  /// accessors above which take a reference, so every call copies the
+  /// converter. The same applies to `setSize`, `stride` and `setStride`.
+  static Value size(OpBuilder &builder, Location loc,
+                    LLVMTypeConverter typeConverter, Value sizeBasePtr,
+                    Value index);
+  /// Builds IR inserting the size[index] into the descriptor.
+  static void setSize(OpBuilder &builder, Location loc,
+                      LLVMTypeConverter typeConverter, Value sizeBasePtr,
+                      Value index, Value size);
+
+  /// Builds IR extracting the pointer to the first element of the stride array.
+  /// It is computed as `sizeBasePtr` advanced by `rank` index-sized elements.
+  static Value strideBasePtr(OpBuilder &builder, Location loc,
+                             LLVMTypeConverter &typeConverter,
+                             Value sizeBasePtr, Value rank);
+  /// Builds IR extracting the stride[index] from the descriptor.
+  /// NOTE(review): the trailing `stride` parameter is not read by the
+  /// definition in StandardToLLVM.cpp; it looks like a leftover copied from
+  /// `setStride` -- confirm before relying on it.
+  static Value stride(OpBuilder &builder, Location loc,
+                      LLVMTypeConverter typeConverter, Value strideBasePtr,
+                      Value index, Value stride);
+  /// Builds IR inserting the stride[index] into the descriptor.
+  static void setStride(OpBuilder &builder, Location loc,
+                        LLVMTypeConverter typeConverter, Value strideBasePtr,
+                        Value index, Value stride);
 };
 
 /// Base class for operation conversions targeting the LLVM IR dialect. It

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 152dbd1e990e..b6dc2cad4d37 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -865,6 +865,155 @@ void UnrankedMemRefDescriptor::computeSizes(
   }
 }
 
+/// Emits a bitcast of `memRefDescPtr` to `elemPtrPtrType` followed by a load
+/// through the resulting pointer; the loaded value is returned as the
+/// allocated pointer.
+Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
+                                             Value memRefDescPtr,
+                                             LLVM::LLVMType elemPtrPtrType) {
+  Value allocatedSlot =
+      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+  return builder.create<LLVM::LoadOp>(loc, allocatedSlot);
+}
+
+/// Emits a bitcast of `memRefDescPtr` to `elemPtrPtrType` and stores
+/// `allocatedPtr` through the resulting pointer.
+void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
+                                               Value memRefDescPtr,
+                                               LLVM::LLVMType elemPtrPtrType,
+                                               Value allocatedPtr) {
+  Value allocatedSlot =
+      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+  builder.create<LLVM::StoreOp>(loc, allocatedPtr, allocatedSlot);
+}
+
+/// Emits IR that loads the aligned pointer: `memRefDescPtr` is viewed as
+/// `elemPtrPtrType` and the element at index 1 is read.
+Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
+                                           LLVMTypeConverter &typeConverter,
+                                           Value memRefDescPtr,
+                                           LLVM::LLVMType elemPtrPtrType) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+
+  // Emission order (bitcast, constant, GEP, load) is kept as-is.
+  Value descBase =
+      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+  Value slotIdx = createIndexAttrConstant(builder, loc, indexTy, 1);
+  Value alignedSlot = builder.create<LLVM::GEPOp>(
+      loc, elemPtrPtrType, descBase, ValueRange({slotIdx}));
+  return builder.create<LLVM::LoadOp>(loc, alignedSlot);
+}
+
+/// Emits IR that stores `alignedPtr` into the element at index 1 of
+/// `memRefDescPtr` viewed as `elemPtrPtrType`.
+void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
+                                             LLVMTypeConverter &typeConverter,
+                                             Value memRefDescPtr,
+                                             LLVM::LLVMType elemPtrPtrType,
+                                             Value alignedPtr) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+
+  // Emission order (bitcast, constant, GEP, store) is kept as-is.
+  Value descBase =
+      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+  Value slotIdx = createIndexAttrConstant(builder, loc, indexTy, 1);
+  Value alignedSlot = builder.create<LLVM::GEPOp>(
+      loc, elemPtrPtrType, descBase, ValueRange({slotIdx}));
+  builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedSlot);
+}
+
+/// Emits IR that loads the offset: the slot at index 2 of `memRefDescPtr`
+/// viewed as `elemPtrPtrType` is reinterpreted as a pointer to the index type
+/// and read.
+Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter &typeConverter,
+                                       Value memRefDescPtr,
+                                       LLVM::LLVMType elemPtrPtrType) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+
+  Value descBase =
+      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+  Value slotIdx = createIndexAttrConstant(builder, loc, indexTy, 2);
+  Value rawOffsetSlot = builder.create<LLVM::GEPOp>(
+      loc, elemPtrPtrType, descBase, ValueRange({slotIdx}));
+  // The slot is addressed as an element pointer; view it as an index pointer.
+  Value offsetSlot = builder.create<LLVM::BitcastOp>(
+      loc, indexTy.getPointerTo(), rawOffsetSlot);
+  return builder.create<LLVM::LoadOp>(loc, offsetSlot);
+}
+
+/// Emits IR that stores `offset` into the slot at index 2 of `memRefDescPtr`
+/// viewed as `elemPtrPtrType`, after reinterpreting that slot as a pointer to
+/// the index type.
+void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
+                                         LLVMTypeConverter &typeConverter,
+                                         Value memRefDescPtr,
+                                         LLVM::LLVMType elemPtrPtrType,
+                                         Value offset) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+
+  Value descBase =
+      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+  Value slotIdx = createIndexAttrConstant(builder, loc, indexTy, 2);
+  Value rawOffsetSlot = builder.create<LLVM::GEPOp>(
+      loc, elemPtrPtrType, descBase, ValueRange({slotIdx}));
+  // The slot is addressed as an element pointer; view it as an index pointer.
+  Value offsetSlot = builder.create<LLVM::BitcastOp>(
+      loc, indexTy.getPointerTo(), rawOffsetSlot);
+  builder.create<LLVM::StoreOp>(loc, offset, offsetSlot);
+}
+
+/// Emits IR computing the address of the first size entry. `memRefDescPtr` is
+/// viewed as a pointer to the struct {elemPtr, elemPtr, index, index} and the
+/// address of field 3 is taken; the result has type "pointer to index".
+Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
+                                            LLVMTypeConverter &typeConverter,
+                                            Value memRefDescPtr,
+                                            LLVM::LLVMType elemPtrPtrType) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+  LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
+  LLVM::LLVMType descPrefixPtrTy =
+      LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)
+          .getPointerTo();
+  LLVM::LLVMType i32Ty = unwrap(typeConverter.convertType(builder.getI32Type()));
+
+  Value descPrefixPtr =
+      builder.create<LLVM::BitcastOp>(loc, descPrefixPtrTy, memRefDescPtr);
+
+  // GEP indices: 0 selects the struct itself, 3 selects its fourth field. The
+  // field index is emitted as an i32 constant.
+  Value zeroIdx = createIndexAttrConstant(builder, loc, indexTy, 0);
+  Value fieldIdx = builder.create<LLVM::ConstantOp>(
+      loc, i32Ty, builder.getI32IntegerAttr(3));
+  return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), descPrefixPtr,
+                                     ValueRange({zeroIdx, fieldIdx}));
+}
+
+/// Emits IR that loads the `index`-th entry of the size array starting at
+/// `sizeBasePtr`.
+Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
+                                     LLVMTypeConverter typeConverter,
+                                     Value sizeBasePtr, Value index) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+  Value sizeSlot = builder.create<LLVM::GEPOp>(
+      loc, indexTy.getPointerTo(), sizeBasePtr, ValueRange({index}));
+  return builder.create<LLVM::LoadOp>(loc, sizeSlot);
+}
+
+/// Emits IR that stores `size` into the `index`-th entry of the size array
+/// starting at `sizeBasePtr`.
+void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter typeConverter,
+                                       Value sizeBasePtr, Value index,
+                                       Value size) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+  Value sizeSlot = builder.create<LLVM::GEPOp>(
+      loc, indexTy.getPointerTo(), sizeBasePtr, ValueRange({index}));
+  builder.create<LLVM::StoreOp>(loc, size, sizeSlot);
+}
+
+/// Emits IR computing the address of the first stride entry, obtained by
+/// advancing `sizeBasePtr` by `rank` index-typed elements.
+Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
+                                              LLVMTypeConverter &typeConverter,
+                                              Value sizeBasePtr, Value rank) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+  return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), sizeBasePtr,
+                                     ValueRange({rank}));
+}
+
+/// Emits IR that loads the `index`-th entry of the stride array starting at
+/// `strideBasePtr`. The `stride` parameter is not read by this function.
+Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter typeConverter,
+                                       Value strideBasePtr, Value index,
+                                       Value stride) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+  Value strideSlot = builder.create<LLVM::GEPOp>(
+      loc, indexTy.getPointerTo(), strideBasePtr, ValueRange({index}));
+  return builder.create<LLVM::LoadOp>(loc, strideSlot);
+}
+
+/// Emits IR that stores `stride` into the `index`-th entry of the stride array
+/// starting at `strideBasePtr`.
+void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
+                                         LLVMTypeConverter typeConverter,
+                                         Value strideBasePtr, Value index,
+                                         Value stride) {
+  LLVM::LLVMType indexTy = typeConverter.getIndexType();
+  Value strideSlot = builder.create<LLVM::GEPOp>(
+      loc, indexTy.getPointerTo(), strideBasePtr, ValueRange({index}));
+  builder.create<LLVM::StoreOp>(loc, stride, strideSlot);
+}
+
 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
   return *typeConverter.getDialect();
 }
@@ -2417,6 +2566,49 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
   }
 };
 
+/// Reads the allocated pointer, the aligned pointer and, when `offset` is
+/// non-null, the offset of a memref operand. `originalOperand` supplies the
+/// memref type that selects the path; `convertedOperand` is the value the
+/// fields are read from. For a ranked memref the fields come straight from the
+/// `MemRefDescriptor`; for an unranked memref they are loaded from the
+/// underlying ranked descriptor that the unranked descriptor points to.
+static void extractPointersAndOffset(Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     LLVMTypeConverter &typeConverter,
+                                     Value originalOperand,
+                                     Value convertedOperand,
+                                     Value *allocatedPtr, Value *alignedPtr,
+                                     Value *offset = nullptr) {
+  const bool wantsOffset = offset != nullptr;
+  Type operandType = originalOperand.getType();
+
+  // Ranked operand: the converted value is the descriptor itself.
+  if (operandType.isa<MemRefType>()) {
+    MemRefDescriptor rankedDesc(convertedOperand);
+    *allocatedPtr = rankedDesc.allocatedPtr(rewriter, loc);
+    *alignedPtr = rankedDesc.alignedPtr(rewriter, loc);
+    if (wantsOffset)
+      *offset = rankedDesc.offset(rewriter, loc);
+    return;
+  }
+
+  // Unranked operand: build the ElemType** type used to address the fields of
+  // the underlying ranked descriptor.
+  auto unrankedType = operandType.cast<UnrankedMemRefType>();
+  unsigned memorySpace = unrankedType.getMemorySpace();
+  LLVM::LLVMType llvmElementType =
+      unwrap(typeConverter.convertType(unrankedType.getElementType()));
+  LLVM::LLVMType elementPtrPtrType =
+      llvmElementType.getPointerTo(memorySpace).getPointerTo();
+
+  // Fetch the pointer to the underlying ranked descriptor and read the fields
+  // through it, in the order allocated pointer, aligned pointer, offset.
+  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+  Value rankedDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+
+  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
+      rewriter, loc, rankedDescPtr, elementPtrPtrType);
+  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+      rewriter, loc, typeConverter, rankedDescPtr, elementPtrPtrType);
+  if (wantsOffset) {
+    *offset = UnrankedMemRefDescriptor::offset(
+        rewriter, loc, typeConverter, rankedDescPtr, elementPtrPtrType);
+  }
+}
+
 struct MemRefReinterpretCastOpLowering
     : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
   using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
@@ -2455,8 +2647,8 @@ struct MemRefReinterpretCastOpLowering
 
     // Set allocated and aligned pointers.
     Value allocatedPtr, alignedPtr;
-    extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
-                    &allocatedPtr, &alignedPtr);
+    extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(),
+                             adaptor.source(), &allocatedPtr, &alignedPtr);
     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
     desc.setAlignedPtr(rewriter, loc, alignedPtr);
 
@@ -2483,45 +2675,155 @@ struct MemRefReinterpretCastOpLowering
     *descriptor = desc;
     return success();
   }
+};
 
-  void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
-                       Value originalOperand, Value convertedOperand,
-                       Value *allocatedPtr, Value *alignedPtr) const {
-    Type operandType = originalOperand.getType();
-    if (operandType.isa<MemRefType>()) {
-      MemRefDescriptor desc(convertedOperand);
-      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
-      *alignedPtr = desc.alignedPtr(rewriter, loc);
-      return;
-    }
+/// Lowers `memref_reshape` whose shape operand has a dynamic length
+/// (`memref<?xindex>`) to an unranked memref descriptor. The underlying ranked
+/// descriptor is allocated with `llvm.alloca`; pointers and offset are copied
+/// from the source, sizes are copied from the shape operand, and strides are
+/// computed in a loop that walks the dimensions from `rank - 1` down to 0.
+/// Statically-shaped shape operands are rejected (the pattern fails to match).
+struct MemRefReshapeOpLowering
+    : public ConvertOpToLLVMPattern<MemRefReshapeOp> {
+  using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
 
-    unsigned memorySpace =
-        operandType.cast<UnrankedMemRefType>().getMemorySpace();
-    Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
-    LLVM::LLVMType llvmElementType =
-        typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
-    LLVM::LLVMType elementPtrPtrType =
-        llvmElementType.getPointerTo(memorySpace).getPointerTo();
+  /// Replaces `op` with the unranked descriptor built by
+  /// `convertSourceMemRefToDescriptor`, or fails if that helper fails.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reshapeOp = cast<MemRefReshapeOp>(op);
 
-    // Extract pointer to the underlying ranked memref descriptor and cast it to
-    // ElemType**.
-    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
-    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
-    Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, elementPtrPtrType, underlyingDescPtr);
+    MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
+    Type srcType = reshapeOp.source().getType();
 
-    LLVM::LLVMType int32Type =
-        typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
+    Value descriptor;
+    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
+                                               adaptor, &descriptor)))
+      return failure();
+    rewriter.replaceOp(op, {descriptor});
+    return success();
+  }
+
+private:
+  /// Emits the IR that builds the result descriptor and writes it to
+  /// `*descriptor`. On success the rewriter's insertion point is left at the
+  /// start of the block that holds the operations which originally followed
+  /// the reshape op. `srcType` is currently unused.
+  LogicalResult
+  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
+                                  Type srcType, MemRefReshapeOp reshapeOp,
+                                  MemRefReshapeOp::Adaptor adaptor,
+                                  Value *descriptor) const {
+    // Conversion for statically-known shape args is performed via
+    // `memref_reinterpret_cast`.
+    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
+    if (shapeMemRefType.hasStaticShape())
+      return failure();
 
-    // Extract and set allocated pointer.
-    *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
+    // The shape is a rank-1 memref with unknown length; its only dimension is
+    // the rank of the result.
+    Location loc = reshapeOp.getLoc();
+    MemRefDescriptor shapeDesc(adaptor.shape());
+    Value resultRank = shapeDesc.size(rewriter, loc, 0);
 
-    // Extract and set aligned pointer.
-    Value one = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(1));
-    Value alignedGep = rewriter.create<LLVM::GEPOp>(
-        loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
-    *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+    // Extract address space and element type.
+    auto targetType =
+        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
+    unsigned addressSpace = targetType.getMemorySpace();
+    Type elementType = targetType.getElementType();
+
+    // Create the unranked memref descriptor that holds the ranked one. The
+    // inner descriptor is allocated on stack.
+    auto targetDesc = UnrankedMemRefDescriptor::undef(
+        rewriter, loc, unwrap(typeConverter.convertType(targetType)));
+    targetDesc.setRank(rewriter, loc, resultRank);
+    SmallVector<Value, 4> sizes;
+    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
+                                           targetDesc, sizes);
+    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
+        loc, getVoidPtrType(), sizes.front(), llvm::None);
+    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
+
+    // Extract pointers and offset from the source memref.
+    Value allocatedPtr, alignedPtr, offset;
+    extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(),
+                             adaptor.source(), &allocatedPtr, &alignedPtr,
+                             &offset);
+
+    // Set pointers and offset.
+    LLVM::LLVMType llvmElementType =
+        unwrap(typeConverter.convertType(elementType));
+    LLVM::LLVMType elementPtrPtrType =
+        llvmElementType.getPointerTo(addressSpace).getPointerTo();
+    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
+                                              elementPtrPtrType, allocatedPtr);
+    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter,
+                                            underlyingDescPtr,
+                                            elementPtrPtrType, alignedPtr);
+    UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter,
+                                        underlyingDescPtr, elementPtrPtrType,
+                                        offset);
+
+    // Use the offset pointer as base for further addressing. Copy over the new
+    // shape and compute strides. For this, we create a loop from rank-1 to 0.
+    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
+        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
+    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
+        rewriter, loc, typeConverter, targetSizesBase, resultRank);
+    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
+    Value oneIndex = createIndexConstant(rewriter, loc, 1);
+    Value resultRankMinusOne =
+        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
+
+    Block *initBlock = rewriter.getInsertionBlock();
+    LLVM::LLVMType indexType = typeConverter.getIndexType();
+    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
+
+    // The condition block carries the loop state as block arguments:
+    // (current dimension index, current stride).
+    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
+                                            {indexType, indexType});
+
+    // Move the remaining initBlock ops to condBlock. Splitting and merging
+    // moves the operations themselves instead of cloning and erasing them, so
+    // their results keep all existing uses, including uses in other blocks.
+    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
+    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
+
+    // Enter the loop with index = rank - 1 and stride = 1.
+    rewriter.setInsertionPointToEnd(initBlock);
+    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
+                                condBlock);
+    rewriter.setInsertionPointToStart(condBlock);
+    Value indexArg = condBlock->getArgument(0);
+    Value strideArg = condBlock->getArgument(1);
+
+    // Loop condition: index >= 0 (signed comparison).
+    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
+    Value pred = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
+        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
+
+    // Everything after the condition (the moved remainder ops) goes to a new
+    // block; the loop body is emitted at its start.
+    Block *bodyBlock =
+        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
+    rewriter.setInsertionPointToStart(bodyBlock);
+
+    // Copy size from shape to descriptor.
+    LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo();
+    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
+        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
+    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
+    UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter,
+                                      targetSizesBase, indexArg, size);
+
+    // Write stride value and compute next one.
+    UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter,
+                                        targetStridesBase, indexArg, strideArg);
+    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
+
+    // Decrement loop counter and branch back.
+    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
+    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
+                                condBlock);
+
+    // Split the moved remainder ops off the loop body into their own block.
+    Block *remainder =
+        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
+
+    // Hook up the cond exit to the remainder.
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
+                                    llvm::None);
+
+    // Reset position to beginning of new remainder block.
+    rewriter.setInsertionPointToStart(remainder);
+
+    *descriptor = targetDesc;
+    return success();
   }
 };
 
@@ -3642,6 +3944,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
       LoadOpLowering,
       MemRefCastOpLowering,
       MemRefReinterpretCastOpLowering,
+      MemRefReshapeOpLowering,
       RankOpLowering,
       StoreOpLowering,
       SubViewOpLowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 8447474484e2..de59472f0713 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -478,9 +478,10 @@ func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
 // CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
 // CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
 // CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
-// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
-// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]]
+// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i64) -> !llvm.ptr<ptr<float>>
 // CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
 // CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
 // CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
@@ -489,3 +490,73 @@ func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
 // CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
 // CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
 // CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
+
+// Exercises `memref_reshape` with a dynamically sized shape operand
+// (`memref<?xindex>`) and an unranked result (`memref<*xf32>`).
+// CHECK-LABEL: @memref_reshape
+func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
+  %output = memref_reshape %input(%shape)
+                : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+  return
+}
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[INPUT_TY:!.*]]
+// CHECK: [[SHAPE:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : [[SHAPE_TY:!.*]]
+// CHECK: [[RANK:%.*]] = llvm.extractvalue [[SHAPE]][3, 0] : [[SHAPE_TY]]
+// CHECK: [[UNRANKED_OUT_O:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[UNRANKED_OUT_1:%.*]] = llvm.insertvalue [[RANK]], [[UNRANKED_OUT_O]][0] : !llvm.struct<(i64, ptr<i8>)>
+
+// Compute size in bytes to allocate result ranked descriptor
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : !llvm.i64
+// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
+// CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x !llvm.i8
+// CHECK: llvm.insertvalue [[UNDERLYING_DESC]], [[UNRANKED_OUT_1]][1]
+
+// Set allocated, aligned pointers and offset.
+// CHECK: [[ALLOC_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[INPUT_TY]]
+// CHECK: [[ALIGN_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[INPUT_TY]]
+// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2] : [[INPUT_TY]]
+// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
+// CHECK-SAME:                     !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: llvm.store [[ALLOC_PTR]], [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]]
+// CHECK: llvm.store [[ALIGN_PTR]], [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR__:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: [[OFFSET_PTR_:%.*]] = llvm.getelementptr [[BASE_PTR_PTR__]]{{\[}}[[C2]]]
+// CHECK: [[OFFSET_PTR:%.*]] = llvm.bitcast [[OFFSET_PTR_]]
+// CHECK: llvm.store [[OFFSET]], [[OFFSET_PTR]] : !llvm.ptr<i64>
+
+// Iterate over shape operand in reverse order and set sizes and strides.
+// CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
+// CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<float>, ptr<float>, i64, i64)>>
+// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
+// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]]
+// CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
+// CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
+// CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[RANK_MIN_1:%.*]] = llvm.sub [[RANK]], [[C1_]] : !llvm.i64
+// CHECK: llvm.br ^bb1([[RANK_MIN_1]], [[C1_]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb1([[DIM:%.*]]: !llvm.i64, [[CUR_STRIDE:%.*]]: !llvm.i64):
+// CHECK:   [[C0_:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK:   [[COND:%.*]] = llvm.icmp "sge" [[DIM]], [[C0_]] : !llvm.i64
+// CHECK:   llvm.cond_br [[COND]], ^bb2, ^bb3
+
+// CHECK: ^bb2:
+// CHECK:   [[SIZE_PTR:%.*]] = llvm.getelementptr [[SHAPE_IN_PTR]]{{\[}}[[DIM]]]
+// CHECK:   [[SIZE:%.*]] = llvm.load [[SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK:   [[TARGET_SIZE_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[DIM]]]
+// CHECK:   llvm.store [[SIZE]], [[TARGET_SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK:   [[TARGET_STRIDE_PTR:%.*]] = llvm.getelementptr [[STRIDES_PTR]]{{\[}}[[DIM]]]
+// CHECK:   llvm.store [[CUR_STRIDE]], [[TARGET_STRIDE_PTR]] : !llvm.ptr<i64>
+// CHECK:   [[UPDATE_STRIDE:%.*]] = llvm.mul [[CUR_STRIDE]], [[SIZE]] : !llvm.i64
+// CHECK:   [[STRIDE_COND:%.*]] = llvm.sub [[DIM]], [[C1_]] : !llvm.i64
+// CHECK:   llvm.br ^bb1([[STRIDE_COND]], [[UPDATE_STRIDE]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb3:
+// CHECK:   llvm.return

diff  --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
index 96a8ae16ae6d..5ec3b31a76b2 100644
--- a/mlir/test/mlir-cpu-runner/memref_reshape.mlir
+++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm --print-ir-after-all \
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
 // RUN: | mlir-cpu-runner -e main -entry-point-result=void \
 // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
 // RUN: | FileCheck %s
@@ -39,6 +39,10 @@ func @main() -> () {
     : (memref<2x3xf32>, memref<2xindex>) -> ()
   call @reshape_unranked_memref_to_ranked(%input, %shape)
     : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_ranked_memref_to_unranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_unranked_memref_to_unranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
   return
 }
 
@@ -50,9 +54,9 @@ func @reshape_ranked_memref_to_ranked(%input : memref<2x3xf32>,
   %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
   call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
-  // CHECK: [0,   1],                          
-  // CHECK: [2,   3],                          
-  // CHECK: [4,   5]  
+  // CHECK: [0,   1],
+  // CHECK: [2,   3],
+  // CHECK: [4,   5]
   return
 }
 
@@ -65,8 +69,37 @@ func @reshape_unranked_memref_to_ranked(%input : memref<2x3xf32>,
   %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
   call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
-  // CHECK: [0,   1],                          
-  // CHECK: [2,   3],                          
-  // CHECK: [4,   5]  
+  // CHECK: [0,   1],
+  // CHECK: [2,   3],
+  // CHECK: [4,   5]
+  return
+}
+
+// Reshapes a ranked input using a shape operand that is first cast to a
+// dynamic length; the unranked result is printed directly.
+func @reshape_ranked_memref_to_unranked(%input : memref<2x3xf32>,
+                                        %shape : memref<2xindex>) {
+  %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+  %output = memref_reshape %input(%dyn_size_shape)
+                : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+
+  call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0,   1],
+  // CHECK: [2,   3],
+  // CHECK: [4,   5]
+  return
+}
+
+// Reshapes an unranked input using a shape operand that is first cast to a
+// dynamic length. The source of the reshape must be the unranked cast of
+// `%input`, otherwise this only repeats the ranked-to-unranked case.
+func @reshape_unranked_memref_to_unranked(%input : memref<2x3xf32>,
+                                          %shape : memref<2xindex>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+  %output = memref_reshape %unranked_input(%dyn_size_shape)
+                : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+  call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0,   1],
+  // CHECK: [2,   3],
+  // CHECK: [4,   5]
   return
 }


        


More information about the Mlir-commits mailing list