[Mlir-commits] [mlir] 9925168 - [mlir] Convert `memref_reshape` to LLVM.
Alexander Belyaev
llvmlistbot at llvm.org
Tue Nov 3 02:39:32 PST 2020
Author: Alexander Belyaev
Date: 2020-11-03T11:39:08+01:00
New Revision: 992516857691edecbbefc8e55c402908884301ba
URL: https://github.com/llvm/llvm-project/commit/992516857691edecbbefc8e55c402908884301ba
DIFF: https://github.com/llvm/llvm-project/commit/992516857691edecbbefc8e55c402908884301ba.diff
LOG: [mlir] Convert `memref_reshape` to LLVM.
https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15
Differential Revision: https://reviews.llvm.org/D90377
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/mlir-cpu-runner/memref_reshape.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 36734f809175..c52de63224b1 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -399,6 +399,65 @@ class UnrankedMemRefDescriptor : public StructBuilder {
LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values,
SmallVectorImpl<Value> &sizes);
+
+ /// TODO: The following accessors don't take alignment rules between elements
+ /// of the descriptor struct into account. For some architectures, it might be
+ /// necessary to extend them and to use `llvm::DataLayout` contained in
+ /// `LLVMTypeConverter`.
+
+ /// Builds IR extracting the allocated pointer from the descriptor.
+ static Value allocatedPtr(OpBuilder &builder, Location loc,
+ Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+ /// Builds IR inserting the allocated pointer into the descriptor.
+ static void setAllocatedPtr(OpBuilder &builder, Location loc,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType,
+ Value allocatedPtr);
+
+ /// Builds IR extracting the aligned pointer from the descriptor.
+ static Value alignedPtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter, Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType);
+ /// Builds IR inserting the aligned pointer into the descriptor.
+ static void setAlignedPtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType,
+ Value alignedPtr);
+
+ /// Builds IR extracting the offset from the descriptor.
+ static Value offset(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter, Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType);
+ /// Builds IR inserting the offset into the descriptor.
+ static void setOffset(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter, Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType, Value offset);
+
+ /// Builds IR extracting the pointer to the first element of the size array.
+ static Value sizeBasePtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+ /// Builds IR extracting the size[index] from the descriptor.
+ static Value size(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter, Value sizeBasePtr,
+ Value index);
+ /// Builds IR inserting the size[index] into the descriptor.
+ static void setSize(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter, Value sizeBasePtr,
+ Value index, Value size);
+
+ /// Builds IR extracting the pointer to the first element of the stride array.
+ static Value strideBasePtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value sizeBasePtr, Value rank);
+ /// Builds IR extracting the stride[index] from the descriptor.
+ static Value stride(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter, Value strideBasePtr,
+ Value index, Value stride);
+ /// Builds IR inserting the stride[index] into the descriptor.
+ static void setStride(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter, Value strideBasePtr,
+ Value index, Value stride);
};
/// Base class for operation conversions targeting the LLVM IR dialect. It
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 152dbd1e990e..b6dc2cad4d37 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -865,6 +865,155 @@ void UnrankedMemRefDescriptor::computeSizes(
}
}
+Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType) {
+
+ Value elementPtrPtr =
+ builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+ return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
+}
+
+void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType,
+ Value allocatedPtr) {
+ Value elementPtrPtr =
+ builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+ builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
+}
+
+Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType) {
+ Value elementPtrPtr =
+ builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+
+ Value one =
+ createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
+ Value alignedGep = builder.create<LLVM::GEPOp>(
+ loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
+ return builder.create<LLVM::LoadOp>(loc, alignedGep);
+}
+
+void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType,
+ Value alignedPtr) {
+ Value elementPtrPtr =
+ builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+
+ Value one =
+ createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
+ Value alignedGep = builder.create<LLVM::GEPOp>(
+ loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
+ builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
+}
+
+Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType) {
+ Value elementPtrPtr =
+ builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+
+ Value two =
+ createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
+ Value offsetGep = builder.create<LLVM::GEPOp>(
+ loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
+ offsetGep = builder.create<LLVM::BitcastOp>(
+ loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+ return builder.create<LLVM::LoadOp>(loc, offsetGep);
+}
+
+void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType,
+ Value offset) {
+ Value elementPtrPtr =
+ builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
+
+ Value two =
+ createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
+ Value offsetGep = builder.create<LLVM::GEPOp>(
+ loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
+ offsetGep = builder.create<LLVM::BitcastOp>(
+ loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+ builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
+}
+
+Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr,
+ LLVM::LLVMType elemPtrPtrType) {
+ LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
+ LLVM::LLVMType indexTy = typeConverter.getIndexType();
+ LLVM::LLVMType structPtrTy =
+ LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)
+ .getPointerTo();
+ Value structPtr =
+ builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
+
+ LLVM::LLVMType int32_type =
+ unwrap(typeConverter.convertType(builder.getI32Type()));
+ Value zero =
+ createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
+ Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
+ builder.getI32IntegerAttr(3));
+ return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), structPtr,
+ ValueRange({zero, three}));
+}
+
+Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter,
+ Value sizeBasePtr, Value index) {
+ LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
+ ValueRange({index}));
+ return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
+}
+
+void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter,
+ Value sizeBasePtr, Value index,
+ Value size) {
+ LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
+ ValueRange({index}));
+ builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
+}
+
+Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ Value sizeBasePtr, Value rank) {
+ LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
+ ValueRange({rank}));
+}
+
+Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter,
+ Value strideBasePtr, Value index,
+ Value stride) {
+ LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ Value strideStoreGep = builder.create<LLVM::GEPOp>(
+ loc, indexPtrTy, strideBasePtr, ValueRange({index}));
+ return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
+}
+
+void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
+ LLVMTypeConverter typeConverter,
+ Value strideBasePtr, Value index,
+ Value stride) {
+ LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ Value strideStoreGep = builder.create<LLVM::GEPOp>(
+ loc, indexPtrTy, strideBasePtr, ValueRange({index}));
+ builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
+}
+
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
return *typeConverter.getDialect();
}
@@ -2417,6 +2566,49 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
}
};
+/// Extracts the allocated pointer, the aligned pointer and (when `offset` is
+/// non-null) the offset from a ranked or unranked memref. In the unranked case,
+/// the fields are extracted from the underlying ranked descriptor.
+static void extractPointersAndOffset(Location loc,
+ ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter,
+ Value originalOperand,
+ Value convertedOperand,
+ Value *allocatedPtr, Value *alignedPtr,
+ Value *offset = nullptr) {
+ Type operandType = originalOperand.getType();
+ if (operandType.isa<MemRefType>()) {
+ MemRefDescriptor desc(convertedOperand);
+ *allocatedPtr = desc.allocatedPtr(rewriter, loc);
+ *alignedPtr = desc.alignedPtr(rewriter, loc);
+ if (offset != nullptr)
+ *offset = desc.offset(rewriter, loc);
+ return;
+ }
+
+ unsigned memorySpace =
+ operandType.cast<UnrankedMemRefType>().getMemorySpace();
+ Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+ LLVM::LLVMType llvmElementType =
+ unwrap(typeConverter.convertType(elementType));
+ LLVM::LLVMType elementPtrPtrType =
+ llvmElementType.getPointerTo(memorySpace).getPointerTo();
+
+ // Extract pointer to the underlying ranked memref descriptor and cast it to
+ // ElemType**.
+ UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+ Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+
+ *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
+ rewriter, loc, underlyingDescPtr, elementPtrPtrType);
+ *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+ rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
+ if (offset != nullptr) {
+ *offset = UnrankedMemRefDescriptor::offset(
+ rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
+ }
+}
+
struct MemRefReinterpretCastOpLowering
: public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
@@ -2455,8 +2647,8 @@ struct MemRefReinterpretCastOpLowering
// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
- extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
- &allocatedPtr, &alignedPtr);
+ extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(),
+ adaptor.source(), &allocatedPtr, &alignedPtr);
desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
desc.setAlignedPtr(rewriter, loc, alignedPtr);
@@ -2483,45 +2675,155 @@ struct MemRefReinterpretCastOpLowering
*descriptor = desc;
return success();
}
+};
- void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
- Value originalOperand, Value convertedOperand,
- Value *allocatedPtr, Value *alignedPtr) const {
- Type operandType = originalOperand.getType();
- if (operandType.isa<MemRefType>()) {
- MemRefDescriptor desc(convertedOperand);
- *allocatedPtr = desc.allocatedPtr(rewriter, loc);
- *alignedPtr = desc.alignedPtr(rewriter, loc);
- return;
- }
+struct MemRefReshapeOpLowering
+ : public ConvertOpToLLVMPattern<MemRefReshapeOp> {
+ using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
- unsigned memorySpace =
- operandType.cast<UnrankedMemRefType>().getMemorySpace();
- Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
- LLVM::LLVMType llvmElementType =
- typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
- LLVM::LLVMType elementPtrPtrType =
- llvmElementType.getPointerTo(memorySpace).getPointerTo();
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reshapeOp = cast<MemRefReshapeOp>(op);
- // Extract pointer to the underlying ranked memref descriptor and cast it to
- // ElemType**.
- UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
- Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
- Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
- loc, elementPtrPtrType, underlyingDescPtr);
+ MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
+ Type srcType = reshapeOp.source().getType();
- LLVM::LLVMType int32Type =
- typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
+ Value descriptor;
+ if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
+ adaptor, &descriptor)))
+ return failure();
+ rewriter.replaceOp(op, {descriptor});
+ return success();
+ }
+
+private:
+ LogicalResult
+ convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
+ Type srcType, MemRefReshapeOp reshapeOp,
+ MemRefReshapeOp::Adaptor adaptor,
+ Value *descriptor) const {
+ // Conversion for statically-known shape args is performed via
+ // `memref_reinterpret_cast`.
+ auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
+ if (shapeMemRefType.hasStaticShape())
+ return failure();
- // Extract and set allocated pointer.
- *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
+ // The shape operand is a rank-1 memref with a dynamic (unknown) length.
+ Location loc = reshapeOp.getLoc();
+ MemRefDescriptor shapeDesc(adaptor.shape());
+ Value resultRank = shapeDesc.size(rewriter, loc, 0);
- // Extract and set aligned pointer.
- Value one = rewriter.create<LLVM::ConstantOp>(
- loc, int32Type, rewriter.getI32IntegerAttr(1));
- Value alignedGep = rewriter.create<LLVM::GEPOp>(
- loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
- *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+ // Extract address space and element type.
+ auto targetType =
+ reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
+ unsigned addressSpace = targetType.getMemorySpace();
+ Type elementType = targetType.getElementType();
+
+ // Create the unranked memref descriptor that holds the ranked one. The
+ // inner descriptor is allocated on the stack.
+ auto targetDesc = UnrankedMemRefDescriptor::undef(
+ rewriter, loc, unwrap(typeConverter.convertType(targetType)));
+ targetDesc.setRank(rewriter, loc, resultRank);
+ SmallVector<Value, 4> sizes;
+ UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
+ targetDesc, sizes);
+ Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
+ loc, getVoidPtrType(), sizes.front(), llvm::None);
+ targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
+
+ // Extract pointers and offset from the source memref.
+ Value allocatedPtr, alignedPtr, offset;
+ extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(),
+ adaptor.source(), &allocatedPtr, &alignedPtr,
+ &offset);
+
+ // Set pointers and offset.
+ LLVM::LLVMType llvmElementType =
+ unwrap(typeConverter.convertType(elementType));
+ LLVM::LLVMType elementPtrPtrType =
+ llvmElementType.getPointerTo(addressSpace).getPointerTo();
+ UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
+ elementPtrPtrType, allocatedPtr);
+ UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter,
+ underlyingDescPtr,
+ elementPtrPtrType, alignedPtr);
+ UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter,
+ underlyingDescPtr, elementPtrPtrType,
+ offset);
+
+ // Use the offset pointer as the base for further addressing. Copy over the
+ // new shape and compute strides with a loop that counts down from rank-1 to 0.
+ Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
+ rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
+ Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
+ rewriter, loc, typeConverter, targetSizesBase, resultRank);
+ Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
+ Value oneIndex = createIndexConstant(rewriter, loc, 1);
+ Value resultRankMinusOne =
+ rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
+
+ Block *initBlock = rewriter.getInsertionBlock();
+ LLVM::LLVMType indexType = typeConverter.getIndexType();
+ Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
+
+ Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
+ {indexType, indexType});
+
+ // Iterate over the remaining ops in initBlock and move them to condBlock.
+ BlockAndValueMapping map;
+ for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) {
+ rewriter.clone(*it, map);
+ rewriter.eraseOp(&*it);
+ }
+
+ rewriter.setInsertionPointToEnd(initBlock);
+ rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
+ condBlock);
+ rewriter.setInsertionPointToStart(condBlock);
+ Value indexArg = condBlock->getArgument(0);
+ Value strideArg = condBlock->getArgument(1);
+
+ Value zeroIndex = createIndexConstant(rewriter, loc, 0);
+ Value pred = rewriter.create<LLVM::ICmpOp>(
+ loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
+ LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
+
+ Block *bodyBlock =
+ rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
+ rewriter.setInsertionPointToStart(bodyBlock);
+
+ // Copy size from shape to descriptor.
+ LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo();
+ Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
+ loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
+ Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
+ UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter,
+ targetSizesBase, indexArg, size);
+
+ // Write stride value and compute next one.
+ UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter,
+ targetStridesBase, indexArg, strideArg);
+ Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
+
+ // Decrement loop counter and branch back.
+ Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
+ rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
+ condBlock);
+
+ Block *remainder =
+ rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
+
+ // Hook up the cond exit to the remainder.
+ rewriter.setInsertionPointToEnd(condBlock);
+ rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
+ llvm::None);
+
+ // Reset position to beginning of new remainder block.
+ rewriter.setInsertionPointToStart(remainder);
+
+ *descriptor = targetDesc;
+ return success();
}
};
@@ -3642,6 +3944,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
LoadOpLowering,
MemRefCastOpLowering,
MemRefReinterpretCastOpLowering,
+ MemRefReshapeOpLowering,
RankOpLowering,
StoreOpLowering,
SubViewOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 8447474484e2..de59472f0713 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -478,9 +478,10 @@ func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
// CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
-// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
-// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]]
+// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i64) -> !llvm.ptr<ptr<float>>
// CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
@@ -489,3 +490,73 @@ func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
+
+// CHECK-LABEL: @memref_reshape
+func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
+ %output = memref_reshape %input(%shape)
+ : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+ return
+}
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[INPUT_TY:!.*]]
+// CHECK: [[SHAPE:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : [[SHAPE_TY:!.*]]
+// CHECK: [[RANK:%.*]] = llvm.extractvalue [[SHAPE]][3, 0] : [[SHAPE_TY]]
+// CHECK: [[UNRANKED_OUT_O:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[UNRANKED_OUT_1:%.*]] = llvm.insertvalue [[RANK]], [[UNRANKED_OUT_O]][0] : !llvm.struct<(i64, ptr<i8>)>
+
+// Compute the size in bytes needed to allocate the resulting ranked descriptor.
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : !llvm.i64
+// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
+// CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x !llvm.i8
+// CHECK: llvm.insertvalue [[UNDERLYING_DESC]], [[UNRANKED_OUT_1]][1]
+
+// Set allocated, aligned pointers and offset.
+// CHECK: [[ALLOC_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[INPUT_TY]]
+// CHECK: [[ALIGN_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[INPUT_TY]]
+// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2] : [[INPUT_TY]]
+// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
+// CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: llvm.store [[ALLOC_PTR]], [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]]
+// CHECK: llvm.store [[ALIGN_PTR]], [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR__:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: [[OFFSET_PTR_:%.*]] = llvm.getelementptr [[BASE_PTR_PTR__]]{{\[}}[[C2]]]
+// CHECK: [[OFFSET_PTR:%.*]] = llvm.bitcast [[OFFSET_PTR_]]
+// CHECK: llvm.store [[OFFSET]], [[OFFSET_PTR]] : !llvm.ptr<i64>
+
+// Iterate over shape operand in reverse order and set sizes and strides.
+// CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
+// CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<float>, ptr<float>, i64, i64)>>
+// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
+// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]]
+// CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
+// CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
+// CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[RANK_MIN_1:%.*]] = llvm.sub [[RANK]], [[C1_]] : !llvm.i64
+// CHECK: llvm.br ^bb1([[RANK_MIN_1]], [[C1_]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb1([[DIM:%.*]]: !llvm.i64, [[CUR_STRIDE:%.*]]: !llvm.i64):
+// CHECK: [[C0_:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[COND:%.*]] = llvm.icmp "sge" [[DIM]], [[C0_]] : !llvm.i64
+// CHECK: llvm.cond_br [[COND]], ^bb2, ^bb3
+
+// CHECK: ^bb2:
+// CHECK: [[SIZE_PTR:%.*]] = llvm.getelementptr [[SHAPE_IN_PTR]]{{\[}}[[DIM]]]
+// CHECK: [[SIZE:%.*]] = llvm.load [[SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK: [[TARGET_SIZE_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[DIM]]]
+// CHECK: llvm.store [[SIZE]], [[TARGET_SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK: [[TARGET_STRIDE_PTR:%.*]] = llvm.getelementptr [[STRIDES_PTR]]{{\[}}[[DIM]]]
+// CHECK: llvm.store [[CUR_STRIDE]], [[TARGET_STRIDE_PTR]] : !llvm.ptr<i64>
+// CHECK: [[UPDATE_STRIDE:%.*]] = llvm.mul [[CUR_STRIDE]], [[SIZE]] : !llvm.i64
+// CHECK: [[STRIDE_COND:%.*]] = llvm.sub [[DIM]], [[C1_]] : !llvm.i64
+// CHECK: llvm.br ^bb1([[STRIDE_COND]], [[UPDATE_STRIDE]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb3:
+// CHECK: llvm.return
diff --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
index 96a8ae16ae6d..5ec3b31a76b2 100644
--- a/mlir/test/mlir-cpu-runner/memref_reshape.mlir
+++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm --print-ir-after-all \
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s
@@ -39,6 +39,10 @@ func @main() -> () {
: (memref<2x3xf32>, memref<2xindex>) -> ()
call @reshape_unranked_memref_to_ranked(%input, %shape)
: (memref<2x3xf32>, memref<2xindex>) -> ()
+ call @reshape_ranked_memref_to_unranked(%input, %shape)
+ : (memref<2x3xf32>, memref<2xindex>) -> ()
+ call @reshape_unranked_memref_to_unranked(%input, %shape)
+ : (memref<2x3xf32>, memref<2xindex>) -> ()
return
}
@@ -50,9 +54,9 @@ func @reshape_ranked_memref_to_ranked(%input : memref<2x3xf32>,
%unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
- // CHECK: [0, 1],
- // CHECK: [2, 3],
- // CHECK: [4, 5]
+ // CHECK: [0, 1],
+ // CHECK: [2, 3],
+ // CHECK: [4, 5]
return
}
@@ -65,8 +69,37 @@ func @reshape_unranked_memref_to_ranked(%input : memref<2x3xf32>,
%unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
- // CHECK: [0, 1],
- // CHECK: [2, 3],
- // CHECK: [4, 5]
+ // CHECK: [0, 1],
+ // CHECK: [2, 3],
+ // CHECK: [4, 5]
+ return
+}
+
+func @reshape_ranked_memref_to_unranked(%input : memref<2x3xf32>,
+ %shape : memref<2xindex>) {
+ %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+ %output = memref_reshape %input(%dyn_size_shape)
+ : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+
+ call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+ // CHECK: [0, 1],
+ // CHECK: [2, 3],
+ // CHECK: [4, 5]
+ return
+}
+
+func @reshape_unranked_memref_to_unranked(%input : memref<2x3xf32>,
+ %shape : memref<2xindex>) {
+ %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+ %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+ %output = memref_reshape %input(%dyn_size_shape)
+ : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+
+ call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+ // CHECK: [0, 1],
+ // CHECK: [2, 3],
+ // CHECK: [4, 5]
return
}
More information about the Mlir-commits
mailing list