[Mlir-commits] [mlir] 30e6033 - [mlir][Linalg] Add TensorsToBuffers support for Constant ops.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Oct 8 06:16:45 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-08T13:15:45Z
New Revision: 30e6033b455bfa4b888eedb2cfe808a61845ed5f
URL: https://github.com/llvm/llvm-project/commit/30e6033b455bfa4b888eedb2cfe808a61845ed5f
DIFF: https://github.com/llvm/llvm-project/commit/30e6033b455bfa4b888eedb2cfe808a61845ed5f.diff
LOG: [mlir][Linalg] Add TensorsToBuffers support for Constant ops.
This revision also adds an end-to-end test that exercises the tensors-to-buffers lowering all the way to executable code on CPU.
Differential revision: https://reviews.llvm.org/D88998
Added:
mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
mlir/lib/Transforms/BufferPlacement.cpp
mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index dcf4b5ec06cb..4e25772d5793 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -20,36 +20,39 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
"Only folds the one-trip loops from Linalg ops on tensors "
"(for testing purposes only)">
];
+ let dependentDialects = ["linalg::LinalgDialect"];
}
def LinalgFusion : FunctionPass<"linalg-fusion"> {
let summary = "Fuse operations in the linalg dialect";
let constructor = "mlir::createLinalgFusionPass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
}
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
- let dependentDialects = ["AffineDialect"];
+ let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
}
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
- let dependentDialects = ["AffineDialect"];
+ let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
}
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()";
- let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
+ let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
}
def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
let summary = "Convert the Linalg operations which work on tensor-type "
"operands or results to use buffers instead";
let constructor = "mlir::createConvertLinalgOnTensorsToBuffersPass()";
+ let dependentDialects = ["linalg::LinalgDialect", "vector::VectorDialect"];
}
def LinalgLowerToParallelLoops
@@ -57,7 +60,7 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
- let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+ let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
}
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@@ -69,13 +72,14 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
Option<"useAlloca", "test-use-alloca", "bool",
/*default=*/"false", "Test generation of alloca'ed buffers.">
];
+ let dependentDialects = ["linalg::LinalgDialect"];
}
def LinalgTiling : FunctionPass<"linalg-tile"> {
let summary = "Tile operations in the linalg dialect";
let constructor = "mlir::createLinalgTilingPass()";
let dependentDialects = [
- "AffineDialect", "scf::SCFDialect"
+ "AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"
];
let options = [
ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
@@ -93,7 +97,7 @@ def LinalgTilingToParallelLoops
"Test generation of dynamic promoted buffers",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
- let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+ let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
}
#endif // MLIR_DIALECT_LINALG_PASSES
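[Editorial note] Each entry added to dependentDialects above becomes a dialect registration in the pass base class generated by GEN_PASS_CLASSES, so the dialects the pass may create are loaded before it runs. A rough sketch of the assumed shape of that generated hook for the LinalgOnTensorsToBuffers pass (illustrative only; the generated code may differ in detail):

  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect>();
    registry.insert<vector::VectorDialect>();
  }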
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 2354cc6abd89..0a8e5b679277 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -20,6 +20,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
class MLIRContext;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f8fc4dead86f..11a8cf74ddd5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -15,6 +15,7 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
def Vector_Dialect : Dialect {
let name = "vector";
@@ -1673,7 +1674,7 @@ def Vector_BitCastOp :
}
def Vector_TypeCastOp :
- Vector_Op<"type_cast", [NoSideEffect]>,
+ Vector_Op<"type_cast", [NoSideEffect, ViewLikeOpInterface]>,
Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
Results<(outs AnyMemRef:$result)> {
let summary = "type_cast op converts a scalar memref to a vector memref";
@@ -1711,6 +1712,8 @@ def Vector_TypeCastOp :
MemRefType getResultMemRefType() {
return getResult().getType().cast<MemRefType>();
}
+ // Implement ViewLikeOpInterface.
+ Value getViewSource() { return memref(); }
}];
let assemblyFormat = [{
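[Editorial note] The practical consequence of tagging vector.type_cast with ViewLikeOpInterface is that analyses that reason about aliasing can now walk through it to the underlying buffer. A minimal sketch of such a walk, assuming the MLIR C++ API of this revision; the helper name is hypothetical:

  #include "mlir/Interfaces/ViewLikeInterface.h"

  // Follow view-like ops (now including vector.type_cast) back to the value
  // they are a view of.
  static mlir::Value lookThroughViews(mlir::Value value) {
    while (auto viewLike = value.getDefiningOp<mlir::ViewLikeOpInterface>())
      value = viewLike.getViewSource();
    return value;
  }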
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
new file mode 100644
index 000000000000..a2bd18c7a3b1
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -convert-linalg-on-tensors-to-buffers -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @foo() -> tensor<4xf32> {
+ %0 = constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+func @main() {
+ %0 = call @foo() : () -> tensor<4xf32>
+
+ // Instead of relying on tensor_store which introduces aliasing, we rely on
+ // the conversion of print_memref_f32(tensor<*xf32>) to
+ // print_memref_f32(memref<*xf32>).
+ // Note that this is skipping a step and we would need at least some function
+ // attribute to declare that this conversion is valid (e.g. when we statically
+ // know that things will play nicely at the C ABI boundary).
+ %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+ call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+ // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+ // CHECK-SAME: rank = 1 offset = 0 sizes = [4] strides = [1] data =
+ // CHECK-NEXT: [1, 2, 3, 4]
+
+ return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ac94a421903d..a95e1006e381 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1098,23 +1098,35 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
bool hasBoundedRewriteRecursion() const final { return true; }
};
-/// Returns true if the memory underlying `memRefType` has a contiguous layout.
-/// Strides are written to `strides`.
-static bool isContiguous(MemRefType memRefType,
- SmallVectorImpl<int64_t> &strides) {
+/// Returns the strides if the memory underlying `memRefType` has a contiguous
+/// static layout.
+static llvm::Optional<SmallVector<int64_t, 4>>
+computeContiguousStrides(MemRefType memRefType) {
int64_t offset;
- auto successStrides = getStridesAndOffset(memRefType, strides, offset);
- bool isContiguous = strides.empty() || strides.back() == 1;
- if (isContiguous) {
- auto sizes = memRefType.getShape();
- for (int index = 0, e = strides.size() - 2; index < e; ++index) {
- if (strides[index] != strides[index + 1] * sizes[index + 1]) {
- isContiguous = false;
- break;
- }
- }
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(memRefType, strides, offset)))
+ return None;
+ if (!strides.empty() && strides.back() != 1)
+ return None;
+ // If no layout or identity layout, this is contiguous by definition.
+ if (memRefType.getAffineMaps().empty() ||
+ memRefType.getAffineMaps().front().isIdentity())
+ return strides;
+
+  // Otherwise, we must determine contiguity from shapes. This can only ever
+ // work in static cases because MemRefType is underspecified to represent
+ // contiguous dynamic shapes in other ways than with just empty/identity
+ // layout.
+ auto sizes = memRefType.getShape();
+ for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+ if (ShapedType::isDynamic(sizes[index + 1]) ||
+ ShapedType::isDynamicStrideOrOffset(strides[index]) ||
+ ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
+ return None;
+ if (strides[index] != strides[index + 1] * sizes[index + 1])
+ return None;
}
- return succeeded(successStrides) && isContiguous;
+ return strides;
}
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
@@ -1150,9 +1162,17 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
- // Only contiguous source tensors supported atm.
- SmallVector<int64_t, 4> strides;
- if (!isContiguous(sourceMemRefType, strides))
+ // Only contiguous source buffers supported atm.
+ auto sourceStrides = computeContiguousStrides(sourceMemRefType);
+ if (!sourceStrides)
+ return failure();
+ auto targetStrides = computeContiguousStrides(targetMemRefType);
+ if (!targetStrides)
+ return failure();
+ // Only support static strides for now, regardless of contiguity.
+ if (llvm::any_of(*targetStrides, [](int64_t stride) {
+ return ShapedType::isDynamicStrideOrOffset(stride);
+ }))
return failure();
auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
@@ -1181,8 +1201,8 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
desc.setSize(rewriter, loc, index, size);
- auto strideAttr =
- rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+ auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
+ (*targetStrides)[index]);
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
desc.setStride(rewriter, loc, index, stride);
}
@@ -1223,8 +1243,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
op->getContext()))
return failure();
// Only contiguous source tensors supported atm.
- SmallVector<int64_t, 4> strides;
- if (!isContiguous(xferOp.getMemRefType(), strides))
+ auto strides = computeContiguousStrides(xferOp.getMemRefType());
+ if (!strides)
return failure();
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
@@ -1380,9 +1400,11 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
private:
enum class PrintConversion {
+ // clang-format off
None,
ZeroExt64,
SignExt64
+ // clang-format on
};
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
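[Editorial note] The contiguity rule encoded by computeContiguousStrides reduces to: the innermost stride is 1 and every outer stride equals the next inner stride times the next inner size, with all quantities static. A self-contained sketch of that rule on plain integers (deliberately not using MLIR types), with a worked example in the comments:

  #include <cstdint>
  #include <vector>

  // Strides must end in 1 and satisfy strides[i] == strides[i + 1] * sizes[i + 1].
  static bool isContiguousLayout(const std::vector<int64_t> &sizes,
                                 const std::vector<int64_t> &strides) {
    if (strides.empty())
      return true;
    if (strides.back() != 1)
      return false;
    for (size_t i = 0; i + 1 < strides.size(); ++i)
      if (strides[i] != strides[i + 1] * sizes[i + 1])
        return false;
    return true;
  }

  // Example: a 4x8 row-major buffer has strides [8, 1] and is contiguous
  // (8 == 1 * 8); strides [16, 1], e.g. from a strided subview, are not.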
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
index 0415aeb8a1fd..dc23a3bd8599 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -18,10 +18,18 @@ namespace mlir {
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
+
namespace scf {
class SCFDialect;
} // end namespace scf
+namespace vector {
+class VectorDialect;
+} // end namespace vector
+
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Linalg/Passes.h.inc"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 3282358f5f41..a9ec621efb9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
@@ -203,6 +204,64 @@ class GenericOpConverter
}
};
+// Rewrite a tensor `constant` to a vector constant followed by a vector store
+// and a vector.type_cast.
+class TensorConstantOpConverter
+ : public BufferAssignmentOpConversionPattern<ConstantOp> {
+public:
+ using BufferAssignmentOpConversionPattern<
+ ConstantOp>::BufferAssignmentOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!op.getType().isa<RankedTensorType>())
+ return failure();
+ auto attr = op.getValue().cast<DenseElementsAttr>();
+
+ Location loc = op.getLoc();
+ MemRefType memrefType =
+ converter->convertType(op.getType()).cast<MemRefType>();
+ VectorType vectorType =
+ VectorType::get(memrefType.getShape(), memrefType.getElementType());
+
+    // A vector constant takes attributes that are compatible with a tensor
+    // constant.
+ Value cstVec =
+ rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+
+ // Alloc a memref<vector<...>>, store the constant and typecast the vector
+ // away.
+ MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
+ Value alloc =
+ rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+ rewriter.create<StoreOp>(loc, cstVec, alloc);
+ rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
+
+ return success();
+ }
+};
+
+// Rewrite a `tensor_cast` as a `memref_cast` with no layout, in the 0-memory
+// space.
+class TensorCastOpConverter
+ : public BufferAssignmentOpConversionPattern<TensorCastOp> {
+public:
+ using BufferAssignmentOpConversionPattern<
+ TensorCastOp>::BufferAssignmentOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (op.getType().hasRank())
+ return failure();
+ Type t = UnrankedMemRefType::get(op.getType().getElementType(),
+ /*memorySpace=*/0);
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
+ return success();
+ }
+};
+
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct ConvertLinalgOnTensorsToBuffers
@@ -213,7 +272,7 @@ struct ConvertLinalgOnTensorsToBuffers
BufferAssignmentTypeConverter converter;
// Mark all Standard operations legal.
- target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
target.addLegalOp<ModuleOp>();
target.addLegalOp<ModuleTerminatorOp>();
@@ -225,12 +284,33 @@ struct ConvertLinalgOnTensorsToBuffers
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
isLegalOperation));
- // Mark Standard Return operations illegal as long as one operand is tensor.
- target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
- return converter.isLegal(returnOp.getOperandTypes());
- });
+ // Mark operations that consume or return tensors illegal.
+ auto isLegal = [&](Operation *op) {
+ if (llvm::any_of(op->getOperandTypes(),
+ [&](Type t) { return !converter.isLegal(t); }))
+ return false;
+ if (llvm::any_of(op->getResultTypes(),
+ [&](Type t) { return !converter.isLegal(t); }))
+ return false;
+ return true;
+ };
+ target.addDynamicallyLegalOp<
+ // clang-format off
+ CallOp,
+ ConstantOp,
+ ConstantIntOp,
+ ConstantIndexOp,
+ ConstantFloatOp,
+ ReturnOp,
+ TensorCastOp
+ // clang-format on
+ >(isLegal);
// Mark the function operation illegal as long as an argument is tensor.
+    // TODO: if the FuncOp only has a declaration (e.g. it refers to an
+    // externally defined symbol such as an external library call), only
+    // convert if some special attribute is set. This will allow more control
+    // of interop across ABI boundaries.
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
llvm::none_of(funcOp.getType().getResults(),
@@ -261,5 +341,11 @@ mlir::createConvertLinalgOnTensorsToBuffersPass() {
void mlir::linalg::populateConvertLinalgOnTensorsToBuffersPatterns(
MLIRContext *context, BufferAssignmentTypeConverter *converter,
OwningRewritePatternList *patterns) {
- patterns->insert<GenericOpConverter>(context, converter);
+ patterns->insert<
+ // clang-format off
+ GenericOpConverter,
+ TensorCastOpConverter,
+ TensorConstantOpConverter
+ // clang-format on
+ >(context, converter);
}
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 9f2c254f91e5..4f97fb836246 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -913,70 +913,75 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
// BufferAssignmentCallOpConverter
//===----------------------------------------------------------------------===//
-/// Performs the actual rewriting step.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
- CallOp callOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
+namespace {
+// This class represents a mapping from a result to a list of values and to
+// some results that have not yet been constructed. Instead, the indices of
+// these results in the operation that will be constructed are known. They
+// will be replaced with the actual values once they become available. The
+// order of insertion into this mapping is significant.
+class CallOpResultMapping {
+public:
+ CallOpResultMapping() { order = 0; };
- // This class represents a mapping from a result to a list of values and some
- // results that have not yet constructed. Instead, the indices of these
- // results in the operation that will be constructed are known. They will be
- // replaced with the actual values when they are available. The order of
- // adding to this mapping is important.
- class ResultMapping {
- public:
- ResultMapping() { order = 0; };
-
- /// Add an available value to the mapping.
- void addMapping(Value value) {
- toValuesMapping.push_back({order++, value});
- }
+ /// Add an available value to the mapping.
+ void addMapping(Value value) { toValuesMapping.push_back({order++, value}); }
- /// Add the index of unavailble result value to the mapping.
- void addMapping(unsigned index) {
- toIndicesMapping.push_back({order++, index});
- }
+  /// Add the index of an unavailable result value to the mapping.
+ void addMapping(unsigned index) {
+ toIndicesMapping.push_back({order++, index});
+ }
- /// This method returns the mapping values list. The unknown result values
- /// that only their indicies are available are replaced with their values.
- void getMappingValues(ValueRange valuesToReplaceIndices,
- SmallVectorImpl<Value> &values) {
- // Append available values to the list.
- SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
- toValuesMapping.end());
- // Replace the indices with the actual values.
- llvm::for_each(
- toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
- assert(entry.second < valuesToReplaceIndices.size() &&
- "The value index is out of range.");
- res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
- });
- // Sort the values based on their adding orders.
- llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
- const std::pair<unsigned, Value> &v2) {
- return v1.first < v2.first;
- });
- // Fill the values.
- llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
- values.push_back(entry.second);
- });
- }
+  /// Returns the list of mapping values. Unknown result values, for which
+  /// only indices are available, are replaced with their actual values.
+ void getMappingValues(ValueRange valuesToReplaceIndices,
+ SmallVectorImpl<Value> &values) {
+ // Append available values to the list.
+ SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
+ toValuesMapping.end());
+ // Replace the indices with the actual values.
+ llvm::for_each(
+ toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
+ assert(entry.second < valuesToReplaceIndices.size() &&
+ "The value index is out of range.");
+ res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
+ });
+ // Sort the values based on their adding orders.
+ llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
+ const std::pair<unsigned, Value> &v2) {
+ return v1.first < v2.first;
+ });
+ // Fill the values.
+ llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
+ values.push_back(entry.second);
+ });
+ }
- private:
- /// Keeping the inserting order of mapping values.
- int order;
+private:
+  /// Tracks the insertion order of mapping values.
+ int order;
- /// Containing the mapping values with their inserting orders.
- SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
+  /// The mapping values, together with their insertion order.
+ SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
- /// Containing the indices of result values with their inserting orders.
- SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
- };
+  /// The indices of result values, together with their insertion order.
+ SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
+};
+} // namespace
+
+/// Performs the actual rewriting step.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+ CallOp callOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
Location loc = callOp.getLoc();
OpBuilder builder(callOp);
SmallVector<Value, 2> newOperands;
+  // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
+  // an externally defined symbol such as an external library call), only
+  // convert if some special attribute is set.
+  // This will allow more control of interop across ABI boundaries.
+
// Create the operands list of the new `CallOp`. It unpacks the decomposable
// values if a decompose callback function has been provided by the user.
for (auto operand : operands) {
@@ -989,7 +994,7 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
// Create the new result types for the new `CallOp` and a mapping from the old
// result to new value(s).
SmallVector<Type, 2> newResultTypes;
- SmallVector<ResultMapping, 4> mappings;
+ SmallVector<CallOpResultMapping, 4> mappings;
mappings.resize(callOp.getNumResults());
for (auto result : llvm::enumerate(callOp.getResults())) {
SmallVector<Type, 2> originTypes;
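[Editorial note] The subtle part of CallOpResultMapping is its ordering contract: already-available values and indices of results that do not exist yet are recorded in one insertion order, and getMappingValues re-merges them in that order once the new CallOp's results are known. A standalone sketch of the same idea, with plain ints standing in for mlir::Value (all names here are illustrative):

  #include <algorithm>
  #include <cassert>
  #include <utility>
  #include <vector>

  class OrderedResultMapping {
  public:
    void addValue(int value) { values.push_back({order++, value}); }
    void addDeferredIndex(unsigned index) { indices.push_back({order++, index}); }

    // `newResults` plays the role of the freshly created call's results.
    std::vector<int> materialize(const std::vector<int> &newResults) const {
      std::vector<std::pair<unsigned, int>> merged(values);
      for (auto [ord, idx] : indices) {
        assert(idx < newResults.size() && "deferred index out of range");
        merged.push_back({ord, newResults[idx]});
      }
      // Restore the original insertion order before emitting the values.
      std::sort(merged.begin(), merged.end(),
                [](auto &a, auto &b) { return a.first < b.first; });
      std::vector<int> out;
      for (auto &entry : merged)
        out.push_back(entry.second);
      return out;
    }

  private:
    unsigned order = 0;
    std::vector<std::pair<unsigned, int>> values;       // already-available values
    std::vector<std::pair<unsigned, unsigned>> indices; // deferred result slots
  };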
diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
index 4339b33a2379..7d714092cb7c 100644
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -123,3 +123,53 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
// CHECK: linalg.copy(%[[OUT_BUF_2]], %[[OUT_2]]) : [[TYPE]], [[TYPE]]
// CHECK: dealloc %[[OUT_BUF_2]] : [[TYPE]]
// CHECK: return
+
+// -----
+
+func @foo() -> tensor<4xf32> {
+// CHECK-LABEL: func @foo(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: memref<4xf32>) {
+
+ %0 = constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32>
+// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<4xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<vector<4xf32>>
+// CHECK-NEXT: store %[[CST]], %[[ALLOC]][] : memref<vector<4xf32>>
+// CHECK-NEXT: %[[RES:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4xf32>> to memref<4xf32>
+
+ return %0 : tensor<4xf32>
+// CHECK-NEXT: linalg.copy(%[[RES]], %[[A]]) : memref<4xf32>, memref<4xf32>
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<vector<4xf32>>
+// CHECK-NEXT: return
+}
+
+func @bar() {
+// CHECK-LABEL: func @bar() {
+
+ %0 = call @foo() : () -> tensor<4xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: call @foo(%[[ALLOC]]) : (memref<4xf32>) -> ()
+
+ // Instead of relying on tensor_store which introduces aliasing, we rely on
+ // the conversion of print_memref_f32(tensor<*xf32>) to
+ // print_memref_f32(memref<*xf32>).
+ // Note that this is skipping a step and we would need at least some function
+ // attribute to declare that this conversion is valid (e.g. when we statically
+ // know that things will play nicely at the C ABI boundary).
+ %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+// CHECK-NEXT: %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
+// CHECK-SAME: memref<4xf32> to memref<*xf32>
+
+ call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+// CHECK-NEXT: call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
+
+ return
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<4xf32>
+// CHECK-NEXT: return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)
+// CHECK-LABEL: func @print_memref_f32(memref<*xf32>)