[Mlir-commits] [mlir] a3adcba - [mlir][Linalg] Implement tiling on tensors
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Oct 6 10:53:42 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-06T17:51:11Z
New Revision: a3adcba645eec31b42ad0a1f727975c5c9c236f0
URL: https://github.com/llvm/llvm-project/commit/a3adcba645eec31b42ad0a1f727975c5c9c236f0
DIFF: https://github.com/llvm/llvm-project/commit/a3adcba645eec31b42ad0a1f727975c5c9c236f0.diff
LOG: [mlir][Linalg] Implement tiling on tensors
This revision implements tiling on tensors as described in:
https://llvm.discourse.group/t/an-update-on-linalg-on-tensors/1878/4
Differential Revision: https://reviews.llvm.org/D88733
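In broad strokes, tiling a LinalgOp with tensor semantics threads the init
tensors through scf.for iter_args, takes subtensors of the operands inside
the loop nest, and writes each tile's result back with subtensor_insert. A
condensed sketch of the generated IR, modeled on the matmul test added below
(SSA names and subscript details are illustrative):

  %r = scf.for %iv = %lb to %ub step %s
       iter_args(%t = %init) -> (tensor<?x?xf32>) {
    %sA = subtensor %A[...] : tensor<?x?xf32> to tensor<?x?xf32>
    %sB = subtensor %B[...] : tensor<?x?xf32> to tensor<?x?xf32>
    %sC = subtensor %t[...] : tensor<?x?xf32> to tensor<?x?xf32>
    %sD = linalg.matmul ins(%sA, %sB : tensor<?x?xf32>, tensor<?x?xf32>)
                        init(%sC : tensor<?x?xf32>) -> tensor<?x?xf32>
    %u = subtensor_insert %sD into %t[...] : tensor<?x?xf32> into tensor<?x?xf32>
    scf.yield %u : tensor<?x?xf32>
  }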
Added:
mlir/test/Dialect/Linalg/tile-tensors.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index c10a1e4f4e04..614fd8d2a7de 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -647,7 +647,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
res.reserve(nExtraOperands);
for (unsigned i = 0; i < nExtraOperands; ++i) {
res.push_back(getOperation()->getOperand(numShapedOperands + i));
- assert((res.back().getType().isSignlessIntOrIndexOrFloat()
+ assert((res.back().getType().isSignlessIntOrIndexOrFloat()
|| res.back().getType().isa<VectorType>()) &&
"expected scalar or vector type");
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e47dafc9bf52..2e566c941894 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -29,6 +29,7 @@ using LinalgLoops = SmallVector<Operation *, 4>;
struct TiledLinalgOp {
LinalgOp op;
SmallVector<Operation *, 8> loops;
+ SmallVector<Value, 4> tensorResults;
};
struct TiledAndFusedLinalgOps {
@@ -371,8 +372,9 @@ struct LinalgBaseTilingPattern : public RewritePattern {
LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+ SmallVectorImpl<Value> &tensorResults) const;
private:
/// LinalgTransformMarker handles special attribute manipulations.
@@ -390,9 +392,14 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
marker, benefit) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (failed(LinalgBaseTilingPattern::matchAndRewrite(op, rewriter)))
+ SmallVector<Value, 4> tensorResults;
+ if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
+ tensorResults)))
return failure();
- rewriter.eraseOp(op);
+ if (tensorResults.empty())
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, tensorResults);
return success();
}
};
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b4e5be58bad7..ffcac5f48aa4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -95,17 +95,17 @@ Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
unsigned consumerIdx,
OperationFolder *folder = nullptr);
-/// Returns the linearized list of all view dimensions in a `linalgOp`. Applying
-/// the inverse, concatenated loopToOperandRangeMaps to this list allows the
-/// derivation of loop ranges for any linalgOp.
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp);
+/// Returns the linearized list of all shape dimensions in a `linalgOp`.
+/// Applying the inverse, concatenated loopToOperandRangeMaps to this list
+/// allows the derivation of loop ranges for any linalgOp.
+SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp);
template <typename ConcreteOpTy>
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOpTy linalgOp) {
- return getViewSizes(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
+SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
+ return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
}
/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
-/// concatenated indexing maps to the result of `getViewSizes`. Returns None if
+/// concatenated indexing maps to the result of `getShape`. Returns None if
/// the bounds computation fails.
Optional<SmallVector<Value, 4>>
getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
@@ -119,11 +119,6 @@ SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
AffineMap map, ValueRange values,
OperationFolder *folder = nullptr);
-/// Returns all the operands of `linalgOp` that are not views.
-/// Asserts that these operands are value types to allow transformations like
-/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
-
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 409f54384aca..747a83414a08 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -315,7 +315,7 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
/// source memref. This is useful to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref_cast operations. Such foldable memref_cast
-/// operations are typically inserted as `view` and `subview` ops are
+/// operations are typically inserted as `view` and `subview` ops and are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bd45c8d667f9..abfc0001ed3e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -199,6 +199,11 @@ static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
if (isTopLevelValue(dimOp.memrefOrTensor()))
return true;
+ // Conservatively handle remaining BlockArguments as non-valid symbols.
+ // E.g. scf.for iterArgs.
+ if (dimOp.memrefOrTensor().isa<BlockArgument>())
+ return false;
+
// The dim op is also okay if its operand memref/tensor is a view/subview
// whose corresponding size is a valid symbol.
Optional<int64_t> index = dimOp.getConstantIndex();
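Note on the AffineOps.cpp change above: the value produced by dim on a
loop-carried block argument (e.g. an scf.for iter_arg) can change from one
iteration to the next, so it cannot be treated as a valid affine symbol. A
minimal sketch of the IR shape this now conservatively rejects (names are
illustrative):

  %r = scf.for %iv = %c0 to %ub step %c1
       iter_args(%t = %init) -> (tensor<?x?xf32>) {
    // %d is tied to this iteration's %t and is not a valid symbol.
    %d = dim %t, %c0 : tensor<?x?xf32>
    ...
  }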
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 7b16a9197f11..585b8810fdc2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -97,7 +97,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
clonedViews.push_back(
b.create<SubViewOp>(loc, view, offsets, sizes, strides));
}
- auto operands = getAssumedNonViewOperands(op);
+ auto operands = op.getAssumedNonShapedOperands();
clonedViews.append(operands.begin(), operands.end());
Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index a9e7a8660230..9e96c8cdc691 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -508,10 +508,10 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(
llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
- SmallVector<Value, 8> sizes = getViewSizes(builder, linalgOp);
+ SmallVector<Value, 8> sizes = getShape(builder, linalgOp);
AffineMap map = concatAffineMaps(maps);
auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
- map, getViewSizes(builder, linalgOp));
+ map, getShape(builder, linalgOp));
SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit(
loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 3e8e0b74c145..f7becae6e328 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -56,18 +56,17 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
// indices of newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
- ArrayRef<Value> allViewSizes,
- ArrayRef<Value> allTileSizes) {
+ ValueRange allShapeSizes, ValueRange allTileSizes) {
assert(allTileSizes.size() == map.getNumResults());
- // Apply `map` to get view sizes in loop order.
- auto viewSizes = applyMapToValues(b, loc, map, allViewSizes);
+ // Apply `map` to get shape sizes in loop order.
+ auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
// Traverse the tile sizes, which are in loop order, erase zeros everywhere.
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
if (isZero(tileSizes[idx - zerosCount])) {
- viewSizes.erase(viewSizes.begin() + idx - zerosCount);
+ shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
tileSizes.erase(tileSizes.begin() + idx - zerosCount);
++zerosCount;
continue;
@@ -78,10 +77,10 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
// Create a new range with the applied tile sizes.
SmallVector<Range, 4> res;
for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
- res.push_back(Range{std_constant_index(0), viewSizes[idx], tileSizes[idx]});
+ res.push_back(
+ Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
return std::make_tuple(res, loopIndexToRangeIndex);
}
-
namespace {
// Helper visitor to determine whether an AffineExpr is tiled.
@@ -93,7 +92,7 @@ namespace {
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
- TileCheck(ArrayRef<Value> tileSizes) : isTiled(false), tileSizes(tileSizes) {}
+ TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
void visitDimExpr(AffineDimExpr expr) {
isTiled |= !isZero(tileSizes[expr.getPosition()]);
@@ -106,7 +105,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
"nonpositive multiplying coefficient");
}
bool isTiled;
- ArrayRef<Value> tileSizes;
+ ValueRange tileSizes;
};
} // namespace
@@ -165,7 +164,6 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
static void transformIndexedGenericOpIndices(
OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
- assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
if (!indexedGenericOp)
return;
@@ -202,7 +200,7 @@ static void transformIndexedGenericOpIndices(
}
}
-static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
+static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
if (!expr)
return false;
TileCheck t(tileSizes);
@@ -210,9 +208,8 @@ static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
return t.isTiled;
}
-// Checks whether the view with index `viewIndex` within `linalgOp` varies with
-// respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
+// Checks whether the `map` varies with respect to a non-zero `tileSize`.
+static bool isTiled(AffineMap map, ValueRange tileSizes) {
if (!map)
return false;
for (unsigned r = 0; r < map.getNumResults(); ++r)
@@ -221,13 +218,11 @@ static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
return false;
}
-static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
- LinalgOp linalgOp, AffineMap map,
- ArrayRef<Value> ivs,
- ArrayRef<Value> tileSizes,
- ArrayRef<Value> allViewSizes) {
- assert(linalgOp.hasBufferSemantics() &&
- "expected linalg op with buffer semantics");
+static SmallVector<Value, 4>
+makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
+ ValueRange operands, AffineMap map, ValueRange ivs,
+ ValueRange tileSizes, ValueRange allShapeSizes) {
+ assert(operands.size() == linalgOp.getShapedOperands().size());
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
[](Value v) { return !isZero(v); })) &&
@@ -235,37 +230,34 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
using namespace edsc::op;
- auto viewSizes = applyMapToValues(b, loc, map, allViewSizes);
+ auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
// Construct (potentially temporary) mins and maxes on which to apply maps
- // that define tile subviews.
- SmallVector<Value, 8> lbs, subViewSizes;
+ // that define tile subshapes.
+ SmallVector<Value, 8> lbs, subShapeSizes;
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
bool isTiled = !isZero(tileSizes[idx]);
lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
// Before composing, we need to make range a closed interval.
- Value size = isTiled ? tileSizes[idx] : viewSizes[idx];
- subViewSizes.push_back(size - std_constant_index(1));
+ Value size = isTiled ? tileSizes[idx] : shapeSizes[idx];
+ subShapeSizes.push_back(size - std_constant_index(1));
}
auto *op = linalgOp.getOperation();
SmallVector<Value, 4> res;
res.reserve(op->getNumOperands());
- auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin();
- for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
- ++viewIndex) {
- Value view = *(viewIteratorBegin + viewIndex);
- auto viewType = view.getType().cast<MemRefType>();
- unsigned rank = viewType.getRank();
- auto mapAttr = linalgOp.indexing_maps()[viewIndex];
- auto map = mapAttr.cast<AffineMapAttr>().getValue();
- // If the view is not tiled, we can use it as is.
+ for (auto en : llvm::enumerate(operands)) {
+ Value shapedOp = en.value();
+ ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
+ unsigned rank = shapedType.getRank();
+ AffineMap map = linalgOp.getIndexingMap(en.index());
+ // If the shape is not tiled, we can use it as is.
if (!isTiled(map, tileSizes)) {
- res.push_back(view);
+ res.push_back(shapedOp);
continue;
}
- // Construct a new subview for the tile.
+ // Construct a new subview / subtensor for the tile.
SmallVector<Value, 4> offsets, sizes, strides;
offsets.reserve(rank);
sizes.reserve(rank);
@@ -273,27 +265,27 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
for (unsigned r = 0; r < rank; ++r) {
if (!isTiled(map.getSubMap({r}), tileSizes)) {
offsets.push_back(std_constant_index(0));
- sizes.push_back(std_dim(view, r));
+ sizes.push_back(std_dim(shapedOp, r));
strides.push_back(std_constant_index(1));
continue;
}
// Tiling creates a new slice at the proper index, the slice step is 1
- // (i.e. the slice view does not subsample, stepping occurs in the loop).
+ // (i.e. the op does not subsample, stepping occurs in the loop).
auto m = map.getSubMap({r});
auto offset = applyMapToValues(b, loc, m, lbs).front();
offsets.push_back(offset);
- auto closedIntSize = applyMapToValues(b, loc, m, subViewSizes).front();
+ auto closedIntSize = applyMapToValues(b, loc, m, subShapeSizes).front();
// Resulting size needs to be made half open interval again.
auto size = closedIntSize + std_constant_index(1);
- // The size of the subview should be trimmed to avoid out-of-bounds
- // accesses, unless we statically know the subview size divides the view
- // size evenly.
- int64_t viewSize = viewType.getDimSize(r);
+ // The size of the subview / subtensor should be trimmed to avoid
+ // out-of-bounds accesses, unless we statically know the subshape size
+ // divides the shape size evenly.
+ int64_t shapeSize = shapedType.getDimSize(r);
auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
- if (ShapedType::isDynamic(viewSize) || !sizeCst ||
- (viewSize % sizeCst.getValue()) != 0) {
+ if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
+ (shapeSize % sizeCst.getValue()) != 0) {
// Compute min(size, dim - offset) to avoid out-of-bounds accesses.
auto minMap = AffineMap::get(
/*dimCount=*/3, /*symbolCount=*/0,
@@ -301,7 +293,7 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
getAffineDimExpr(/*position=*/1, b.getContext()) -
getAffineDimExpr(/*position=*/2, b.getContext())},
b.getContext());
- auto d = std_dim(view, r);
+ auto d = std_dim(shapedOp, r);
size =
affine_min(b.getIndexType(), minMap, ValueRange{size, d, offset});
}
@@ -310,7 +302,12 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
strides.push_back(std_constant_index(1));
}
- res.push_back(b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+ if (shapedType.isa<MemRefType>())
+ res.push_back(
+ b.create<SubViewOp>(loc, shapedOp, offsets, sizes, strides));
+ else
+ res.push_back(
+ b.create<SubTensorOp>(loc, shapedOp, offsets, sizes, strides));
}
return res;
@@ -318,7 +315,7 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
template <typename LoopTy>
static Optional<TiledLinalgOp>
-tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
+tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
const LinalgTilingOptions &options) {
auto nLoops = op.getNumLoops();
// Initial tile sizes may be too big, only take the first nLoops.
@@ -335,20 +332,20 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
}
// 1. Build the tiled loop ranges.
- auto allViewSizes = getViewSizes(b, op);
+ auto allShapeSizes = getShape(b, op);
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (asserted in the inverse calculation).
auto mapsRange = op.indexing_maps().getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(
llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
- auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
- if (!viewSizesToLoopsMap)
+ auto shapeSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
+ if (!shapeSizesToLoopsMap)
return llvm::None;
SmallVector<Range, 4> loopRanges;
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
- b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);
+ b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
SmallVector<Attribute, 4> iteratorTypes;
for (auto attr :
enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
@@ -380,29 +377,77 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
// 2. Create the tiled loops.
LinalgOp res = op;
- SmallVector<Value, 4> ivs;
+ SmallVector<Value, 4> ivs, tensorResults;
+ auto initTensors = op.getInitTensors();
GenerateLoopNest<LoopTy>::doit(
- loopRanges, /*iterArgInitValues*/ {}, iteratorTypes,
+ loopRanges, /*iterArgInitValues*/ initTensors, iteratorTypes,
[&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
ivs.assign(localIvs.begin(), localIvs.end());
- SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
- // If we have to apply a permutation to the tiled loop nest, we have to
- // reorder the induction variables This permutation is the right one
- // assuming that loopRanges have previously been permuted by
- // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation
- // of that one: (d0,d1,d2)->(d2,d0,d1)
+ // When an `interchangeVector` is present, it has been applied to the
+ // loop ranges and the iterator types. Apply its inverse to the
+ // resulting loop `ivs` to match the op definition.
+ SmallVector<Value, 4> interchangedIvs;
if (!options.interchangeVector.empty())
- ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);
-
- auto views = makeTiledViews(b, loc, op, viewSizesToLoopsMap, ivValues,
- tileSizes, allViewSizes);
- auto operands = getAssumedNonViewOperands(op);
- views.append(operands.begin(), operands.end());
- res = op.clone(b, loc, /*resultTypes*/ {}, views);
- return scf::ValueVector{};
+ interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
+ else
+ interchangedIvs.assign(ivs.begin(), ivs.end());
+
+ assert(op.getNumInitTensors() == iterArgs.size() &&
+ "num init tensors must match number of loop iter arguments");
+ // This uses knowledge about position of the init tensor in the list
+ // of operands.
+ auto operands = llvm::to_vector<4>(op.getShapedOperands());
+ std::copy(iterArgs.begin(), iterArgs.end(),
+ operands.begin() + op.getNumInputsAndOutputBuffers());
+
+ SmallVector<Value, 4> tiledOperands =
+ makeTiledShapes(b, loc, op, operands, shapeSizesToLoopsMap,
+ interchangedIvs, tileSizes, allShapeSizes);
+ auto nonShapedOperands = op.getAssumedNonShapedOperands();
+ tiledOperands.append(nonShapedOperands.begin(),
+ nonShapedOperands.end());
+
+ // If LinalgOp has results, they must all be tied to init tensors.
+ // We enforce this to ensure all tiled ops have been rewritten in
+ // "init tensor" form. This ensures tiling has anchor values into which
+ // to subtensor / subtensor_insert. Otherwise tiling would need to
+ // allocate which is not acceptable.
+ // This would not be the case with a special terminator op that
+ // generates the whole tensor (instead of inserting a subtensor). But
+ // the generator-based abstraction has other issues.
+ assert(op.getNumInitTensors() == op.getOperation()->getNumResults() &&
+ "expected same number of init tensors as number of results");
+
+ // Handle init tensor operands.
+ // This uses knowledge about position of the init tensor in the list
+ // of operands.
+ // TODO: InterfaceAdaptor ?
+ SmallVector<Type, 4> resultTensorTypes;
+ for (auto idx : llvm::seq<unsigned>(0, op.getNumInitTensors()))
+ resultTensorTypes.push_back(
+ tiledOperands[op.getNumInputsAndOutputBuffers() + idx].getType());
+
+ res = op.clone(b, loc, resultTensorTypes, tiledOperands);
+
+ // Insert a subtensor_insert for each init subtensor.
+ for (unsigned idx = 0, e = op.getNumInitTensors(); idx != e; ++idx) {
+ Value initTensor =
+ tiledOperands[op.getNumInputsAndOutputBuffers() + idx];
+ if (auto subtensor = initTensor.getDefiningOp<SubTensorOp>()) {
+ tensorResults.push_back(b.create<SubTensorInsertOp>(
+ loc, subtensor.source().getType(),
+ res.getOperation()->getResult(idx), subtensor.source(),
+ subtensor.offsets(), subtensor.sizes(), subtensor.strides(),
+ subtensor.static_offsets(), subtensor.static_sizes(),
+ subtensor.static_strides()));
+ } else {
+ tensorResults.push_back(res.getOperation()->getResult(idx));
+ }
+ }
+ return scf::ValueVector(tensorResults.begin(), tensorResults.end());
},
options.distribution);
@@ -422,7 +467,16 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
loops.push_back(nullptr);
}
}
- return TiledLinalgOp{res, loops};
+
+ // 5. Get the tensor results from the outermost loop if available. Otherwise
+ // use the previously captured `tensorResults`.
+ Operation *outermostLoop = nullptr;
+ for (Operation *loop : loops)
+ if ((outermostLoop = loop))
+ break;
+
+ return TiledLinalgOp{
+ res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
template <typename LoopTy>
@@ -432,7 +486,6 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
b.setInsertionPoint(op);
ScopedContext scope(b, op.getLoc());
- assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
// Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
@@ -513,7 +566,9 @@ mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
+ SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
SubViewOp::getCanonicalizationPatterns(patterns, ctx);
+ TensorCastOp::getCanonicalizationPatterns(patterns, ctx);
ViewOp::getCanonicalizationPatterns(patterns, ctx);
CanonicalizationPatternList<
#define GET_OP_LIST
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 56652cbcb527..71e3108b2b58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -111,19 +111,34 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
: RewritePattern(opName, {}, benefit, context), marker(marker),
options(options) {}
-LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
+LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
+ Operation *op, PatternRewriter &rewriter,
+ SmallVectorImpl<Value> &tensorResults) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
+ // If LinalgOp has results, they must all be tied to init tensors.
+ // We enforce this to ensure all tiled ops have been rewritten in
+ // "init tensor" form. This ensures tiling has anchor values into which to
+ // subtensor / subtensor_insert. Otherwise tiling would need to allocate which
+ // is not acceptable.
+ // This would not be the case with a special terminator op that generates the
+ // whole tensor (instead of inserting a subtensor). But the generator-based
+ // abstraction has other issues.
+ if (linalgOp.getNumInitTensors() != linalgOp.getOperation()->getNumResults())
+ return failure();
+
Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
if (!res)
return failure();
+ // Return relevant information to derived pattern.
+ tensorResults = res->tensorResults;
+
// New marker if specified.
marker.replaceLinalgMarker(rewriter, res->op.getOperation());
return success();
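For reference, a minimal sketch of how a client might drive the refactored
pattern (not part of this commit; the pass boilerplate and exact driver
signature are assumptions, while LinalgTilingPattern, LinalgTilingOptions,
and setTileSizes all appear in this diff):

  OwningRewritePatternList patterns;
  // The derived LinalgTilingPattern forwards to matchAndRewriteBase and
  // replaces the op with the returned tensorResults when tiling on tensors.
  patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
      ctx, linalg::LinalgTilingOptions().setTileSizes({2, 3, 4}));
  applyPatternsAndFoldGreedily(funcOp, patterns);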
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index f9ea9092d55d..3f29949ffe63 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -85,26 +85,6 @@ SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
return res;
}
-/// Returns all the operands of `linalgOp` that are not views.
-/// Asserts that these operands are value types to allow transformations like
-/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value, 4>
-mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
- auto *op = linalgOp.getOperation();
- unsigned numViews = linalgOp.getNumInputsAndOutputs();
- unsigned nOperands = op->getNumOperands() - numViews;
- SmallVector<Value, 4> res;
- res.reserve(nOperands);
- for (unsigned i = 0; i < nOperands; ++i) {
- res.push_back(op->getOperand(numViews + i));
- auto t = res.back().getType();
- (void)t;
- assert((t.isSignlessIntOrIndexOrFloat() || t.isa<VectorType>()) &&
- "expected scalar or vector type");
- }
- return res;
-}
-
bool mlir::linalg::isParallelIteratorType(Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>()) {
return strAttr.getValue() == getParallelIteratorTypeName();
@@ -147,12 +127,12 @@ namespace mlir {
namespace linalg {
/// Return the linearized list of all view dimensions in a linalgOp.
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp) {
+SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
auto loc = linalgOp.getLoc();
SmallVector<Value, 8> res;
SmallVector<unsigned, 4> ranks;
- for (auto v : linalgOp.getInputsAndOutputBuffers()) {
- MemRefType t = v.getType().template cast<MemRefType>();
+ for (Value v : linalgOp.getShapedOperands()) {
+ ShapedType t = v.getType().template cast<ShapedType>();
ranks.push_back(t.getRank());
for (unsigned i = 0; i < t.getRank(); ++i)
res.push_back(builder.create<DimOp>(loc, v, i));
@@ -181,7 +161,7 @@ SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp) {
Optional<SmallVector<Value, 4>>
getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
- SmallVector<Value, 8> viewSizes = getViewSizes(builder, linalgOp);
+ SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
AffineMap invertedMap =
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
if (!invertedMap)
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
new file mode 100644
index 000000000000..b899cb3e0049
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -mlir-disable-threading=true | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @matmul_tensors(
+ %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: init(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+ init(%arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+
+// CHECK: return %[[TD0]] : tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}