[Mlir-commits] [mlir] c1fd430 - [mlir] Add basic support for dynamic tensor results in TensorToBuffers.cpp.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Oct 8 02:56:11 PDT 2020
Author: Alexander Belyaev
Date: 2020-10-08T11:55:42+02:00
New Revision: c1fd4305b68500c754a7ce6a86fe297c36e21d3b
URL: https://github.com/llvm/llvm-project/commit/c1fd4305b68500c754a7ce6a86fe297c36e21d3b
DIFF: https://github.com/llvm/llvm-project/commit/c1fd4305b68500c754a7ce6a86fe297c36e21d3b.diff
LOG: [mlir] Add basic support for dynamic tensor results in TensorToBuffers.cpp.
The simplest case handled is when every result of the indexing maps is a plain dimension (DimId); this covers cwise (element-wise) ops.
Also:
* Expose populateConvertLinalgOnTensorsToBuffersPatterns in Transforms.h
* Expose emitLoopRanges in Transforms.h
Differential Revision: https://reviews.llvm.org/D88781
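A minimal sketch of how a downstream pass might wire up the newly exposed entry point, mirroring what ConvertLinalgOnTensorsToBuffers does further down in this diff. The helper name runTensorsToBuffers and the assumption that a ConversionTarget and a BufferAssignmentTypeConverter are already configured are illustrative, not part of the commit:

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/BufferPlacement.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: populate and apply the tensors-to-buffers patterns, assuming
// `target` and `converter` are set up as in ConvertLinalgOnTensorsToBuffers.
static LogicalResult runTensorsToBuffers(Operation *op, ConversionTarget &target,
                                         BufferAssignmentTypeConverter &converter) {
  OwningRewritePatternList patterns;
  // Newly exposed in Transforms.h: converts linalg.generic on tensors into
  // its buffer form (GenericOpConverter).
  linalg::populateConvertLinalgOnTensorsToBuffersPatterns(op->getContext(),
                                                          &converter, &patterns);
  // Companion patterns that adapt function signatures, returns and copies.
  populateWithBufferAssignmentOpConversionPatterns<ReturnOp, ReturnOp,
                                                   linalg::CopyOp>(
      op->getContext(), &converter, &patterns);
  return applyFullConversion(op, target, patterns);
}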
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2e566c941894..395db396dadc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -16,6 +16,9 @@
#include "llvm/ADT/SmallBitVector.h"
namespace mlir {
+
+class BufferAssignmentTypeConverter;
+
namespace linalg {
struct LinalgFusionOptions;
@@ -45,6 +48,12 @@ void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
ArrayRef<int64_t> tileSizes);
+/// Populates the given list with patterns to convert Linalg operations on
+/// tensors to buffers.
+void populateConvertLinalgOnTensorsToBuffersPatterns(
+ MLIRContext *context, BufferAssignmentTypeConverter *converter,
+ OwningRewritePatternList *patterns);
+
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
@@ -246,6 +255,16 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
LinalgPromotionOptions options,
OperationFolder *folder = nullptr);
+/// Creates a number of ranges equal to the number of dimensions in the `map`.
+/// The returned ranges correspond to the loop ranges, in the proper order, for
+/// which new loops will be created.
+/// The function supports only maps that are invertible and have results of type
+/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
+/// It expects a non-inverted, concatenated map; the last values in
+/// `viewSizes` are applied to the symbols in the map, if it contains any.
+SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+ ValueRange viewSizes);
+
/// Emit a suitable vector form for a Linalg op with fully static shape.
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
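As a usage illustration for the newly exposed emitLoopRanges (not taken from the commit; the helper name loopRangesOf is invented here), the sketch below concatenates the input indexing maps of a linalg.generic and queries its loop ranges. It is essentially the computeLoopRanges helper this patch adds to TensorsToBuffers.cpp further down, and assumes the same includes as that file:

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;
using namespace mlir::linalg;

// Sketch: derive the loop ranges of a linalg.generic from its input indexing
// maps and operand shapes.
static SmallVector<Range, 4> loopRangesOf(linalg::GenericOp op, OpBuilder &b) {
  Location loc = op.getLoc();
  auto maps = llvm::to_vector<4>(
      op.indexing_maps().getAsValueRange<AffineMapAttr>());
  auto inputMaps = llvm::makeArrayRef(maps).take_front(op.getNumInputs());
  // emitLoopRanges expects the non-inverted, concatenated map; trailing
  // entries of the size vector bind to the map's symbols, if any.
  edsc::ScopedContext scope(b, loc);
  return emitLoopRanges(scope.getBuilderRef(), loc,
                        concatAffineMaps(inputMaps), getShape(b, op));
}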
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 9e96c8cdc691..b95469d8a955 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -58,77 +58,6 @@ static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
: SmallVector<Value, 4>(ivs.begin(), ivs.end());
}
-/// Creates a number of ranges equal to the number of dimensions in the `map`.
-/// The returned ranges correspond to the loop ranges, in the proper order, for
-/// which new loops will be created.
-/// The function supports only maps that are invertible and have results of type
-/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
-/// It expects a non-inverted, concatenated map and last values in
-/// allViewSizes will be applied to the symbols in the map if it contains any.
-static SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange viewSizes) {
- unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
- unsigned numSym = map.getNumSymbols();
- assert(viewSizes.size() == numRes + numSym &&
- "viewSizes must contain sizes of all views and values for symbols");
- SmallVector<Range, 4> res(numDims);
- for (unsigned idx = 0; idx < numRes; ++idx) {
- auto result = map.getResult(idx);
- if (auto d = result.dyn_cast<AffineDimExpr>()) {
- if (res[d.getPosition()].offset)
- continue;
- res[d.getPosition()] =
- Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
- }
-
- // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
- // then the bounds are:
- // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
- // where size(n) is applied to the symbol s.
- // This is done statically now.
- if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
- auto lhs = binOp.getLHS().dyn_cast<AffineBinaryOpExpr>();
- auto rhs = binOp.getRHS().dyn_cast<AffineBinaryOpExpr>();
- if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add ||
- lhs.getKind() != AffineExprKind::Add ||
- rhs.getKind() != mlir::AffineExprKind::Mul)
- continue;
-
- auto m = lhs.getLHS().dyn_cast<AffineDimExpr>();
- auto n = lhs.getRHS().dyn_cast<AffineDimExpr>();
- auto fDiv = rhs.getLHS().dyn_cast<AffineBinaryOpExpr>();
- auto minusOne = rhs.getRHS().dyn_cast<AffineConstantExpr>();
- if (!m || !n || !fDiv || !minusOne ||
- fDiv.getKind() != AffineExprKind::FloorDiv ||
- fDiv.getLHS().getKind() != AffineExprKind::SymbolId ||
- fDiv.getRHS().getKind() != AffineExprKind::Constant)
- continue;
-
- auto s = fDiv.getLHS().dyn_cast<AffineSymbolExpr>();
- if (minusOne.getValue() != -1)
- continue;
-
- int mPos = m.getPosition();
- AffineExpr one = getAffineConstantExpr(1, s.getContext());
- AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext());
- // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
- AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
- AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv);
- AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr);
- SmallVector<Value, 8> values(viewSizes.begin(),
- viewSizes.begin() + numDims);
- values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end());
- values.push_back(viewSizes[mPos]);
- // Construction of the lower bound (s floordiv 2).
- Value from = applyMapToValues(b, loc, fromMap, values).front();
- Value to = applyMapToValues(b, loc, toMap, values).front();
- res[mPos] = Range{from, to, std_constant_index(1)};
- }
- }
- return res;
-}
-
template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
ArrayRef<SmallVector<Value, 8>> indexing,
@@ -708,6 +637,70 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
}
+SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange viewSizes) {
+ unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+ unsigned numSym = map.getNumSymbols();
+ assert(viewSizes.size() == numRes + numSym &&
+ "viewSizes must contain sizes of all views and values for symbols");
+ SmallVector<Range, 4> res(numDims);
+ for (unsigned idx = 0; idx < numRes; ++idx) {
+ auto result = map.getResult(idx);
+ if (auto d = result.dyn_cast<AffineDimExpr>()) {
+ if (res[d.getPosition()].offset)
+ continue;
+ res[d.getPosition()] =
+ Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
+ }
+
+ // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
+ // then the bounds are:
+ // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
+ // where size(n) is applied to the symbol s.
+ // This is done statically now.
+ if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+ auto lhs = binOp.getLHS().dyn_cast<AffineBinaryOpExpr>();
+ auto rhs = binOp.getRHS().dyn_cast<AffineBinaryOpExpr>();
+ if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add ||
+ lhs.getKind() != AffineExprKind::Add ||
+ rhs.getKind() != mlir::AffineExprKind::Mul)
+ continue;
+
+ auto m = lhs.getLHS().dyn_cast<AffineDimExpr>();
+ auto n = lhs.getRHS().dyn_cast<AffineDimExpr>();
+ auto fDiv = rhs.getLHS().dyn_cast<AffineBinaryOpExpr>();
+ auto minusOne = rhs.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!m || !n || !fDiv || !minusOne ||
+ fDiv.getKind() != AffineExprKind::FloorDiv ||
+ fDiv.getLHS().getKind() != AffineExprKind::SymbolId ||
+ fDiv.getRHS().getKind() != AffineExprKind::Constant)
+ continue;
+
+ auto s = fDiv.getLHS().dyn_cast<AffineSymbolExpr>();
+ if (minusOne.getValue() != -1)
+ continue;
+
+ int mPos = m.getPosition();
+ AffineExpr one = getAffineConstantExpr(1, s.getContext());
+ AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext());
+ // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
+ AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
+ AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv);
+ AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr);
+ SmallVector<Value, 8> values(viewSizes.begin(),
+ viewSizes.begin() + numDims);
+ values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end());
+ values.push_back(viewSizes[mPos]);
+ // Construction of the lower bound (s floordiv 2).
+ Value from = applyMapToValues(b, loc, fromMap, values).front();
+ Value to = applyMapToValues(b, loc, toMap, values).front();
+ res[mPos] = Range{from, to, std_constant_index(1)};
+ }
+ }
+ return res;
+}
+
/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
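For intuition about the special-cased access pattern (m + n - s floordiv 2) handled in emitLoopRanges above (a reading not spelled out in the commit): taking the emitted range for m as the half-open interval

  [ s floordiv 2,  size(m) + s floordiv 2 - s + 1 )

the loop executes size(m) - s + 1 iterations, which matches the output extent of an unpadded ("valid") convolution whose window size is bound to the symbol s.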
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index b714a1f6c642..3282358f5f41 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -14,14 +14,119 @@
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/BufferPlacement.h"
-using namespace mlir;
-
namespace {
+
+using namespace ::mlir;
+using namespace ::mlir::linalg;
+
+SmallVector<Range, 4>
+computeLoopRanges(Location loc, linalg::GenericOp linalgOp, OpBuilder *b) {
+ auto indexingMaps = llvm::to_vector<4>(
+ linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
+ auto inputIndexingMaps =
+ llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());
+
+ mlir::edsc::ScopedContext scope(*b, loc);
+ return emitLoopRanges(scope.getBuilderRef(), loc,
+ concatAffineMaps(inputIndexingMaps),
+ getShape(*b, linalgOp));
+}
+
+Value maybeConvertToIndex(Location loc, Value val, OpBuilder *b) {
+ if (val.getType().isIndex())
+ return val;
+ return b->create<IndexCastOp>(loc, val, b->getIndexType());
+}
+
+LogicalResult allocateBuffersForResults(Location loc,
+ linalg::GenericOp linalgOp,
+ linalg::GenericOpAdaptor &adaptor,
+ SmallVectorImpl<Value> *resultBuffers,
+ OpBuilder *b) {
+ // Lazily compute loopRanges.
+ SmallVector<Range, 4> loopRanges;
+
+ // Allocate a buffer for every tensor result.
+ for (auto en : llvm::enumerate(linalgOp.getResultTypes())) {
+ size_t resultIndex = en.index();
+ Type resultType = en.value();
+
+ auto tensorType = resultType.dyn_cast<RankedTensorType>();
+ if (tensorType == nullptr) {
+ linalgOp.emitOpError()
+ << "tensor to buffer conversion expects ranked tensor results";
+ return failure();
+ }
+ auto tensorShape = tensorType.getShape();
+ auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
+
+ // Allocate buffers for init tensors that are assumed to fold onto the first
+ // results.
+ // TODO: update this assumption because the reality is more complex
+ // under linalg on tensor based transformations.
+ bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
+ if (foldedInitTensor) {
+ // Dealing with an init tensor requires distinguishing between 1-use
+ // and many-use cases which would create aliasing and WAR hazards.
+ Value initTensor = linalgOp.getInitTensor(resultIndex);
+ Value initBuffer = adaptor.init_tensors()[resultIndex];
+ if (initTensor.hasOneUse()) {
+ resultBuffers->push_back(initBuffer);
+ continue;
+ }
+ SmallVector<Value, 4> dynOperands;
+ for (auto dim : llvm::enumerate(tensorShape)) {
+ if (dim.value() == TensorType::kDynamicSize) {
+ dynOperands.push_back(b->create<DimOp>(loc, initTensor, dim.index()));
+ }
+ }
+ auto alloc = b->create<AllocOp>(loc, memrefType, dynOperands);
+ b->create<linalg::CopyOp>(loc, initBuffer, alloc);
+ resultBuffers->push_back(alloc);
+ continue;
+ }
+
+ // Allocate buffers for statically-shaped results.
+ if (memrefType.hasStaticShape()) {
+ resultBuffers->push_back(b->create<AllocOp>(loc, memrefType));
+ continue;
+ }
+
+ // Perform a naive shape inference for the dynamically-shaped results.
+ // Extract the required element out of the vector.
+ SmallVector<Value, 4> dynOperands;
+ auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
+ for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
+ if (loopRanges.empty())
+ loopRanges = computeLoopRanges(loc, linalgOp, b);
+
+ if (shapeElement.value() != ShapedType::kDynamicSize)
+ continue;
+
+ AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
+ switch (expr.getKind()) {
+ case AffineExprKind::DimId: {
+ int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
+ Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
+ dynOperands.push_back(size);
+ break;
+ }
+ default:
+ return failure();
+ }
+ }
+ resultBuffers->push_back(b->create<AllocOp>(loc, memrefType, dynOperands));
+ }
+ return success();
+}
+
/// A pattern to convert Generic Linalg operations which work on tensors to
/// use buffers. A buffer is allocated using BufferAssignmentPlacer for
/// each operation result. BufferPlacement pass should be later used to move
@@ -34,10 +139,10 @@ class GenericOpConverter
linalg::GenericOp>::BufferAssignmentOpConversionPattern;
LogicalResult
- matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
+ matchAndRewrite(linalg::GenericOp linalgOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- linalg::GenericOpAdaptor adaptor(operands,
- op.getOperation()->getAttrDictionary());
+ linalg::GenericOpAdaptor adaptor(
+ operands, linalgOp.getOperation()->getAttrDictionary());
// All inputs need to be turned into buffers first. Until then, bail out.
if (llvm::any_of(adaptor.inputs(),
@@ -50,93 +155,54 @@ class GenericOpConverter
[](Value in) { return !in.getType().isa<MemRefType>(); }))
return failure();
- Location loc = op.getLoc();
- SmallVector<Value, 2> newOutputBuffers;
- newOutputBuffers.reserve(op.getNumOutputs());
- newOutputBuffers.append(adaptor.output_buffers().begin(),
- adaptor.output_buffers().end());
-
- // Update all types to memref types.
- // Assume the init tensors fold onto the first results.
- // TODO: update this assumption because the reality is more complex under
- // linalg on tensor based transformations.
- for (auto en : llvm::enumerate(op.getResultTypes())) {
- auto type = en.value().cast<ShapedType>();
- if (!type.hasStaticShape())
- return rewriter.notifyMatchFailure(
- op, "dynamic shapes not currently supported");
- auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
- bool foldedInitTensor = en.index() < op.getNumInitTensors();
- if (foldedInitTensor) {
- // Dealing with an init tensor requires distinguishing between 1-use
- // and many-use cases which would create aliasing and WAR hazards.
- Value initTensor = op.getInitTensor(en.index());
- Value initBuffer = adaptor.init_tensors()[en.index()];
- if (initTensor.hasOneUse()) {
- newOutputBuffers.push_back(initBuffer);
- continue;
- }
- auto alloc = rewriter.create<AllocOp>(loc, memrefType);
- rewriter.create<linalg::CopyOp>(loc, initBuffer, alloc);
- newOutputBuffers.push_back(alloc);
- } else {
- auto alloc = rewriter.create<AllocOp>(loc, memrefType);
- newOutputBuffers.push_back(alloc);
- }
+ Location loc = linalgOp.getLoc();
+ SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
+ adaptor.output_buffers().end());
+
+ if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
+ &newOutputBuffers, &rewriter))) {
+ linalgOp.emitOpError()
+ << "Failed to allocate buffers for tensor results.";
+ return failure();
}
// Generate a new linalg operation that works on buffers.
- auto linalgOp = rewriter.create<linalg::GenericOp>(
+ auto newLinalgOp = rewriter.create<linalg::GenericOp>(
loc,
- /*resultTensorTypes=*/ArrayRef<Type>{},
+ /*resultTensorTypes=*/llvm::None,
/*inputs=*/adaptor.inputs(),
/*outputBuffers=*/newOutputBuffers,
- /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(),
- op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr());
+ /*initTensors=*/llvm::None, linalgOp.indexing_maps(),
+ linalgOp.iterator_types(), linalgOp.docAttr(),
+ linalgOp.library_callAttr(), linalgOp.symbol_sourceAttr());
// Create a new block in the region of the new Generic Op.
- Block &oldBlock = op.getRegion().front();
- Region &newRegion = linalgOp.region();
+ Block *oldBlock = linalgOp.getBody();
+ Region &newRegion = newLinalgOp.region();
Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
- oldBlock.getArgumentTypes());
-
- // Add the result arguments that do not come from init_tensors to the new
- // block.
- // TODO: update this assumption because the reality is more complex under
- // linalg on tensor based transformations.
- for (Value v :
- ValueRange(newOutputBuffers).drop_front(adaptor.init_tensors().size()))
+ oldBlock->getArgumentTypes());
+
+ // Add the result arguments to the new block.
+ for (Value v : newOutputBuffers)
newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
// Clone the body of the old block to the new block.
BlockAndValueMapping mapping;
- for (unsigned i = 0; i < oldBlock.getNumArguments(); i++)
- mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i));
+ mapping.map(oldBlock->getArguments(), newBlock->getArguments());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(newBlock);
- for (auto &op : oldBlock.getOperations()) {
+ for (auto &op : oldBlock->getOperations()) {
Operation *clonedOp = rewriter.clone(op, mapping);
mapping.map(op.getResults(), clonedOp->getResults());
}
// Replace the results of the old op with the new output buffers.
- rewriter.replaceOp(op, newOutputBuffers);
+ rewriter.replaceOp(linalgOp, newOutputBuffers);
return success();
}
};
-/// Populate the given list with patterns to convert Linalg operations on
-/// tensors to buffers.
-static void populateConvertLinalgOnTensorsToBuffersPattern(
- MLIRContext *context, BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
- populateWithBufferAssignmentOpConversionPatterns<
- mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter,
- patterns);
- patterns->insert<GenericOpConverter>(context, converter);
-}
-
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct ConvertLinalgOnTensorsToBuffers
@@ -176,8 +242,11 @@ struct ConvertLinalgOnTensorsToBuffers
BufferAssignmentTypeConverter::AppendToArgumentsList);
OwningRewritePatternList patterns;
- populateConvertLinalgOnTensorsToBuffersPattern(&context, &converter,
- &patterns);
+ populateConvertLinalgOnTensorsToBuffersPatterns(&context, &converter,
+ &patterns);
+ populateWithBufferAssignmentOpConversionPatterns<
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, &converter,
+ &patterns);
if (failed(applyFullConversion(this->getOperation(), target, patterns)))
this->signalPassFailure();
}
@@ -188,3 +257,9 @@ std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertLinalgOnTensorsToBuffersPass() {
return std::make_unique<ConvertLinalgOnTensorsToBuffers>();
}
+
+void mlir::linalg::populateConvertLinalgOnTensorsToBuffersPatterns(
+ MLIRContext *context, BufferAssignmentTypeConverter *converter,
+ OwningRewritePatternList *patterns) {
+ patterns->insert<GenericOpConverter>(context, converter);
+}
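A concrete reading of the shape-inference step in allocateBuffersForResults, using the @dynamic_results test added below: the second result has indexing map (d0, d1) -> (d1, d0) and type tensor<?x?xf32>; both map results are DimIds, so the buffer is allocated with the sizes of loops d1 and d0 in that order. That is why the test expects alloc(%[[DIM_1]], %[[DIM_0]]) for the second buffer and alloc(%[[DIM_0]], %[[DIM_1]]) for the identity-mapped first result.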
diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
index 654a13fca743..4339b33a2379 100644
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -2,11 +2,13 @@
#map0 = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func @multiple_results_generic_op
-func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
- %0, %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]}
- ins(%arg0 : tensor<4xf32>) {
- ^bb0(%gen_arg1: f32):
+// CHECK-LABEL: func @multiple_results
+func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+ %0, %1 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel"]
+ } ins(%arg0 : tensor<4xf32>) {
+ ^bb0(%gen_arg1: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1, %tmp1 : f32, f32
} -> tensor<4xf32>, tensor<4xf32>
@@ -34,15 +36,20 @@ func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tenso
// CHECK-LABEL: func @chained_operations
func @chained_operations(%arg0: tensor<4xf32>) -> tensor<4xf32> {
- %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
- ins(%arg0 : tensor<4xf32>) {
- ^bb0(%gen_arg1: f32):
+ %0 = linalg.generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel"]
+ } ins(%arg0 : tensor<4xf32>) {
+ ^bb0(%gen_arg1: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1 : f32
} -> tensor<4xf32>
- %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
- ins(%0 : tensor<4xf32>) {
- ^bb0(%gen_arg2: f32):
+
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel"]
+ } ins(%0 : tensor<4xf32>) {
+ ^bb0(%gen_arg2: f32):
%tmp2 = exp %gen_arg2 : f32
linalg.yield %tmp2 : f32
} -> tensor<4xf32>
@@ -73,6 +80,46 @@ func @no_linalg_op(%arg0: f32) -> (f32, f32) {
%0 = mulf %arg0, %arg0 : f32
return %0, %0 : f32, f32
}
-// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]])
-// CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]]
-// CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]]
+// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]])
+// CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]]
+// CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]]
+
+// -----
+
+#map_2d = affine_map<(d0, d1) -> (d0, d1)>
+#map_2d_inv = affine_map<(d0, d1) -> (d1, d0)>
+
+func @dynamic_results(%arg0: tensor<?x?xf32>)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0, %1 = linalg.generic {
+ indexing_maps = [#map_2d, #map_2d, #map_2d_inv],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%gen_arg1: f32):
+ %tmp1 = exp %gen_arg1 : f32
+ linalg.yield %tmp1, %tmp1 : f32, f32
+ } -> tensor<?x?xf32>, tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #map1 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @dynamic_results
+// CHECK-SAME: (%[[INPUT:.*]]: [[TYPE:.*]], %[[OUT_1:.*]]: [[TYPE]], %[[OUT_2:.*]]: [[TYPE]]) {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[DIM_0:.*]] = dim %[[INPUT]], %[[C0]] : [[TYPE]]
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[DIM_1:.*]] = dim %[[INPUT]], %[[C1]] : [[TYPE]]
+// CHECK: %[[OUT_BUF_1:.*]] = alloc(%[[DIM_0]], %[[DIM_1]]) : [[TYPE]]
+// CHECK: %[[OUT_BUF_2:.*]] = alloc(%[[DIM_1]], %[[DIM_0]]) : [[TYPE]]
+
+// CHECK: linalg.generic {indexing_maps = [#map0, #map0, #map1], {{.*}}}
+// CHECK-SAME: ins(%[[INPUT]] : [[TYPE]])
+// CHECK-SAME: outs(%[[OUT_BUF_1]], %[[OUT_BUF_2]] : [[TYPE]], [[TYPE]]) {
+
+// CHECK: linalg.copy(%[[OUT_BUF_1]], %[[OUT_1]]) : [[TYPE]], [[TYPE]]
+// CHECK: dealloc %[[OUT_BUF_1]] : [[TYPE]]
+// CHECK: linalg.copy(%[[OUT_BUF_2]], %[[OUT_2]]) : [[TYPE]], [[TYPE]]
+// CHECK: dealloc %[[OUT_BUF_2]] : [[TYPE]]
+// CHECK: return