[Mlir-commits] [mlir] 01c4418 - [mlir][Linalg] NFC - Factor out Linalg functionality for shape and loop bounds computation
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 23 02:21:21 PST 2020
Author: Nicolas Vasilache
Date: 2020-11-23T10:17:18Z
New Revision: 01c4418544b7934f8216a6616562bbaf34dc6979
URL: https://github.com/llvm/llvm-project/commit/01c4418544b7934f8216a6616562bbaf34dc6979
DIFF: https://github.com/llvm/llvm-project/commit/01c4418544b7934f8216a6616562bbaf34dc6979.diff
LOG: [mlir][Linalg] NFC - Factor out Linalg functionality for shape and loop bounds computation
This revision refactors code used in various Linalg transformations and makes it a first-class citizen of the LinalgStructuredOpInterface. This is in preparation for allowing more advanced Linalg behavior but is otherwise NFC.
Differential revision: https://reviews.llvm.org/D91863
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 713fb192f073..f8002279132f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -11,12 +11,13 @@
#include "mlir/Dialect/Linalg/IR/LinalgTraits.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -32,10 +33,29 @@ namespace mlir {
namespace linalg {
class ConvOp;
+class LinalgOp;
class PoolingMaxOp;
class PoolingMinOp;
class PoolingSumOp;
+// TODO: allow an extra ValueRange to specify an indexing and allow
+// non-hyperrectangular shapes.
+using LoopRangeBuilder =
+ std::function<SmallVector<Range, 4>(OpBuilder &, Location)>;
+
+/// Returns the values obtained by applying `map` to the list of values.
+SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
+ AffineMap map, ValueRange values);
+
+/// Provide a very simple inference procedure to build the loop ranges from the
+/// op and its operands. This only works with permutation affine maps and
+/// patterns of the form `(m, n)[s] -> (m + n - s floordiv 2)`.
+/// A more advanced Tensor-Comprehension like inference is possible but has
+/// proven to be ambiguous in unfavorable cases.
+/// As a consequence, we relax the default behavior very conservatively and
+/// provide an op-specified hook so that Linalg ops may override the behavior.
+LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
+
using ReassociationIndices = SmallVector<int64_t, 2>;
using ReassociationExprs = SmallVector<AffineExpr, 2>;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 0373bf3f6adf..6c7da083d7af 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -765,6 +765,59 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
}]
>,
+ //===------------------------------------------------------------------===//
+ // Linalg generalization hooks.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Hook to provide a custom AffineMap used to compute all the operand
+ subshapes given loop bounds. This is used to answer the question: "given
+ an iteration space over the codomain, what are the subshapes of the
+ operands involved in the computation".
+ The default behavior is to just concatenate all the indexing maps.
+ A custom AffineMap allows providing a map that can be used to
+ compute subshapes even in cases where the concatenation of indexing maps
+ (i.e. the data traversal order) is not a simple permutation of the loop
+ traversal order. It is then possible to define ops with skewed data
+ traversal order for which we can still easily compute hyperrectangular
+ loop bounds and subviews.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getLoopsToShapesMap",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto r = $_op.indexing_maps().template getAsRange<AffineMapAttr>();
+ auto maps = llvm::to_vector<8>(
+ llvm::map_range(r, [](AffineMapAttr a) { return a.getValue(); }));
+ return concatAffineMaps(maps);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Hook to provide a custom AffineMap used to construct the
+ hyperrectangular loop iteration space given all the operand subshapes.
+ This is used to answer the question:
+ "Given a list of operand ranges, what is the subportion of the iteration
+ space involved in the computation".
+ This is the inverse problem of `getLoopsToShapesMap`.
+ Return the empty AffineMap when such an AffineMap cannot be constructed.
+ The default behavior is based on a very simple inference procedure that
+ only works with permutation affine maps.
+ A more advanced Tensor-Comprehension like inference is possible but has
+ proven to be ambiguous in unfavorable cases.
+ A safer and more robust alternative is to allow each op to define
+ its own AffineMap.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getShapesToLoopsMap",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return inversePermutation(getLoopsToShapesMap());
+ }]
+ >,
+
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
@@ -818,6 +871,15 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
];
let extraClassDeclaration = [{
+ /// Return the flat list of all operand dimension sizes in the order they
+ /// appear in the operands.
+ SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
+
+ /// Create the loop ranges to materialize the computation over the current
+ /// operands. This is done by matching the dimension results of
+ /// `getLoopsToShapesMap` against `createFlatListOfOperandDims`.
+ SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);
+
/// Returns all the operands past the inputs, output_buffers and
/// init_tensors operands. Asserts that these operands are value types to
/// allow transformations like tiling to just use the values when cloning
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d531a1e343a..b7cfa6f023a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -256,16 +256,6 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
LinalgPromotionOptions options,
OperationFolder *folder = nullptr);
-/// Creates a number of ranges equal to the number of dimensions in the `map`.
-/// The returned ranges correspond to the loop ranges, in the proper order, for
-/// which new loops will be created.
-/// The function supports only maps that are invertible and have results of type
-/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
-/// It expects a non-inverted, concatenated map and last values in
-/// allViewSizes will be applied to the symbols in the map if it contains any.
-SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
- ValueRange viewSizes);
-
/// Emit a suitable vector form for a Linalg op with fully static shape.
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index f5669e383368..a6b8afdce9d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -105,36 +105,16 @@ Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
Operation *consumer,
unsigned consumerIdx);
-/// Returns the linearized list of all shape dimensions in a `linalgOp`.
-/// Applying the inverse, concatenated loopToOperandRangeMaps to this list
-/// allows the derivation of loop ranges for any linalgOp.
-SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp);
-template <typename ConcreteOpTy>
-SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
- return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
-}
-
/// Like `getShape`, but only returns statically-known information, without
/// generating any new IR. For each shape dimension, returns >=0 if that
/// dimension is statically known, or -1 otherwise.
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
-/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
-/// concatenated indexing maps to the result of `getShape`. Returns None if
-/// the bounds computation fails.
-Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
- LinalgOp linalgOp);
-
/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
/// Returns None if inverting the concatenated indexing map fails. Returns -1
/// for non-statically-known loop ranges.
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
-
-/// Returns the values obtained by applying `map` to the list of values.
-SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
- AffineMap map, ValueRange values);
-
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b2ac41027b7b..188e00b53940 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -11,18 +11,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
@@ -34,6 +30,132 @@
using namespace mlir;
using namespace mlir::linalg;
+/// Fully compose map with operands and canonicalize the result.
+/// Return the `createOrFold`'ed AffineApply op.
+static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange operandsRef) {
+ SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
+ fullyComposeAffineMapAndOperands(&map, &operands);
+ canonicalizeMapAndOperands(&map, &operands);
+ return b.createOrFold<AffineApplyOp>(loc, map, operands);
+}
+
+SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange values) {
+ SmallVector<Value, 4> res;
+ res.reserve(map.getNumResults());
+ unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
+ // For each `expr` in `map`, applies the `expr` to the values extracted from
+ // ranges. If the resulting application can be folded into a Value, the
+ // folding occurs eagerly.
+ for (auto expr : map.getResults()) {
+ AffineMap map = AffineMap::get(numDims, numSym, expr);
+ res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
+ }
+ return res;
+}
+
+SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
+ Location loc) {
+ SmallVector<Value, 4> res;
+ SmallVector<unsigned, 4> ranks;
+ for (Value v : getShapedOperands()) {
+ ShapedType t = v.getType().template cast<ShapedType>();
+ ranks.push_back(t.getRank());
+ for (unsigned i = 0; i < t.getRank(); ++i)
+ res.push_back(b.create<DimOp>(loc, v, i));
+ }
+
+ // TODO: drop the following once symbol_source is deleted.
+ auto attr = getAttrOfType<IntegerAttr>("symbol_source");
+ if (!attr)
+ return res;
+
+ // Find the correct position for inserting values for symbols.
+ unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
+ for (unsigned idx = 0, e = attr.getInt(); idx < e; idx++)
+ symbolsPos += ranks[idx];
+
+ // Append the end of the value list that corresponds to the
+ // values mapping to symbols. Since inside concatenated map symbols
+ // are repeated we have to repeat the sizes as well.
+
+ // Reserve is mandatory to avoid a potential undefined behavior with
+ // pushing back to smallvector from itself.
+ res.reserve(res.size() + ranks.size() * numSymb);
+ for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
+ for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
+ res.push_back(res[symbolsPos + idx2]);
+ return res;
+}
+
+SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
+ AffineMap map = getLoopsToShapesMap();
+ unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+ // TODO: drop numSym once symbol_source is deleted.
+ unsigned numSym = map.getNumSymbols();
+ auto viewSizes = createFlatListOfOperandDims(b, loc);
+ SmallVector<Range, 4> res(numDims);
+ Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
+ Value oneVal = b.create<ConstantIndexOp>(loc, 1);
+ for (unsigned idx = 0; idx < numRes; ++idx) {
+ auto result = map.getResult(idx);
+ if (auto d = result.dyn_cast<AffineDimExpr>()) {
+ if (res[d.getPosition()].offset)
+ continue;
+ res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
+ }
+
+ // TODO: drop the following once symbol_source is deleted.
+ // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
+ // then the bounds are:
+ // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
+ // where size(n) is applied to the symbol s.
+ // This is done statically now.
+ if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+ auto lhs = binOp.getLHS().dyn_cast<AffineBinaryOpExpr>();
+ auto rhs = binOp.getRHS().dyn_cast<AffineBinaryOpExpr>();
+ if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add ||
+ lhs.getKind() != AffineExprKind::Add ||
+ rhs.getKind() != mlir::AffineExprKind::Mul)
+ continue;
+
+ auto m = lhs.getLHS().dyn_cast<AffineDimExpr>();
+ auto n = lhs.getRHS().dyn_cast<AffineDimExpr>();
+ auto fDiv = rhs.getLHS().dyn_cast<AffineBinaryOpExpr>();
+ auto minusOne = rhs.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!m || !n || !fDiv || !minusOne ||
+ fDiv.getKind() != AffineExprKind::FloorDiv ||
+ !fDiv.getLHS().isa<AffineSymbolExpr>() ||
+ !fDiv.getRHS().isa<AffineConstantExpr>())
+ continue;
+
+ auto s = fDiv.getLHS().dyn_cast<AffineSymbolExpr>();
+ if (minusOne.getValue() != -1)
+ continue;
+
+ int mPos = m.getPosition();
+ AffineExpr one = getAffineConstantExpr(1, s.getContext());
+ AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext());
+ // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
+ AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
+ AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv);
+ AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr);
+ SmallVector<Value, 8> values(viewSizes.begin(),
+ viewSizes.begin() + numDims);
+ values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end());
+ values.push_back(viewSizes[mPos]);
+ // Construction of the lower bound (s floordiv 2).
+ Value from = applyMapToValues(b, loc, fromMap, values).front();
+ Value to = applyMapToValues(b, loc, toMap, values).front();
+ res[mPos] = Range{from, to, oneVal};
+ }
+ }
+ return res;
+}
+
/// Forward declarations.
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributes(
@@ -504,11 +626,15 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
<< idx << " results to match view rank: " << view;
}
+ // TODO: symbol_source prevents us from just writing:
+ // if (!op.getShapesToLoopsMap())
+ // return op.emitOpError("expected the shape-to-loops map to be non-null");
+ //
+ // Update when symbol_source is deleted.
auto concatMap = concatAffineMaps(indexingMaps);
// TODO: Bound inference for maps with symbols
if (!concatMap.getNumSymbols() && !inversePermutation(concatMap))
- return op.emitOpError("expected the concatenation of maps in indexing_map "
- "to be invertible");
+ return op.emitOpError("expected the shape-to-loops map to be non-null");
if (failed(AnnotationsVerifier<GenericOpType>::verify(op)))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 43891631ca2c..d3d0ff40f124 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -14,26 +14,13 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
using namespace ::mlir;
using namespace ::mlir::linalg;
-static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
- OpBuilder &b) {
- auto indexingMaps = llvm::to_vector<4>(
- linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
- auto inputIndexingMaps =
- llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());
-
- mlir::edsc::ScopedContext scope(b, loc);
- return emitLoopRanges(scope.getBuilderRef(), loc,
- concatAffineMaps(inputIndexingMaps),
- getShape(b, linalgOp));
-}
-
static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
if (val.getType().isIndex())
return val;
@@ -97,11 +84,9 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
if (loopRanges.empty())
- loopRanges = computeLoopRanges(loc, linalgOp, b);
-
+ loopRanges = linalgOp.createLoopRanges(b, loc);
if (shapeElement.value() != ShapedType::kDynamicSize)
continue;
-
AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
switch (expr.getKind()) {
case AffineExprKind::DimId: {
@@ -284,7 +269,7 @@ class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to an tensor_to_memref + subview + copy + tensor_load pattern.
-/// tensor_to_memref and tensor_load are inserted automatically by the
+/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 6c46dbf07acc..c9132e93d13e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -526,14 +526,7 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
auto linalgOp = cast<LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
- auto mapsRange =
- linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
- auto maps = llvm::to_vector<8>(
- llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
- SmallVector<Value, 8> sizes = getShape(builder, linalgOp);
- AffineMap map = concatAffineMaps(maps);
- auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
- map, getShape(builder, linalgOp));
+ auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit(
loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),
@@ -669,70 +662,6 @@ mlir::createConvertLinalgToAffineLoopsPass() {
return std::make_unique<LowerToAffineLoops>();
}
-SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange viewSizes) {
- unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
- unsigned numSym = map.getNumSymbols();
- assert(viewSizes.size() == numRes + numSym &&
- "viewSizes must contain sizes of all views and values for symbols");
- SmallVector<Range, 4> res(numDims);
- for (unsigned idx = 0; idx < numRes; ++idx) {
- auto result = map.getResult(idx);
- if (auto d = result.dyn_cast<AffineDimExpr>()) {
- if (res[d.getPosition()].offset)
- continue;
- res[d.getPosition()] =
- Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
- }
-
- // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
- // then the bounds are:
- // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
- // where size(n) is applied to the symbol s.
- // This is done statically now.
- if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
- auto lhs = binOp.getLHS().dyn_cast<AffineBinaryOpExpr>();
- auto rhs = binOp.getRHS().dyn_cast<AffineBinaryOpExpr>();
- if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add ||
- lhs.getKind() != AffineExprKind::Add ||
- rhs.getKind() != mlir::AffineExprKind::Mul)
- continue;
-
- auto m = lhs.getLHS().dyn_cast<AffineDimExpr>();
- auto n = lhs.getRHS().dyn_cast<AffineDimExpr>();
- auto fDiv = rhs.getLHS().dyn_cast<AffineBinaryOpExpr>();
- auto minusOne = rhs.getRHS().dyn_cast<AffineConstantExpr>();
- if (!m || !n || !fDiv || !minusOne ||
- fDiv.getKind() != AffineExprKind::FloorDiv ||
- fDiv.getLHS().getKind() != AffineExprKind::SymbolId ||
- fDiv.getRHS().getKind() != AffineExprKind::Constant)
- continue;
-
- auto s = fDiv.getLHS().dyn_cast<AffineSymbolExpr>();
- if (minusOne.getValue() != -1)
- continue;
-
- int mPos = m.getPosition();
- AffineExpr one = getAffineConstantExpr(1, s.getContext());
- AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext());
- // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
- AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
- AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv);
- AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr);
- SmallVector<Value, 8> values(viewSizes.begin(),
- viewSizes.begin() + numDims);
- values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end());
- values.push_back(viewSizes[mPos]);
- // Construction of the lower bound (s floordiv 2).
- Value from = applyMapToValues(b, loc, fromMap, values).front();
- Value to = applyMapToValues(b, loc, toMap, values).front();
- res[mPos] = Range{from, to, std_constant_index(1)};
- }
- }
- return res;
-}
-
/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index b4809ac4d7c4..197bdbc1a99f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -332,13 +332,8 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
}
// 1. Build the tiled loop ranges.
- auto allShapeSizes = getShape(b, op);
- // The flattened loopToOperandRangesMaps is expected to be an invertible
- // permutation map (asserted in the inverse calculation).
- auto mapsRange = op.indexing_maps().getAsRange<AffineMapAttr>();
- auto maps = llvm::to_vector<8>(
- llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
- auto shapeSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
+ auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
+ AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
if (!shapeSizesToLoopsMap)
return llvm::None;
@@ -367,10 +362,11 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
continue;
interchangeVector.push_back(it->second);
}
+ // Interchange vector is guaranteed to be a permutation,
+ // `inversePermutation` must succeed.
invPermutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, b.getContext()));
- if (!invPermutationMap)
- return llvm::None;
+ assert(invPermutationMap);
applyPermutationToVector(loopRanges, interchangeVector);
applyPermutationToVector(iteratorTypes, interchangeVector);
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c9769476baec..43f40163da81 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -57,31 +57,6 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
return llvm::None;
}
-static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange operandsRef) {
- SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
- fullyComposeAffineMapAndOperands(&map, &operands);
- canonicalizeMapAndOperands(&map, &operands);
- return b.createOrFold<AffineApplyOp>(loc, map, operands);
-}
-
-SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange values) {
- SmallVector<Value, 4> res;
- res.reserve(map.getNumResults());
- unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
- // For each `expr` in `map`, applies the `expr` to the values extracted from
- // ranges. If the resulting application can be folded into a Value, the
- // folding occurs eagerly.
- for (auto expr : map.getResults()) {
- AffineMap map = AffineMap::get(numDims, numSym, expr);
- res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
- }
- return res;
-}
-
bool mlir::linalg::isParallelIteratorType(Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>()) {
return strAttr.getValue() == getParallelIteratorTypeName();
@@ -123,39 +98,6 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
namespace mlir {
namespace linalg {
-/// Return the linearized list of all view dimensions in a linalgOp.
-SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
- auto loc = linalgOp.getLoc();
- SmallVector<Value, 8> res;
- SmallVector<unsigned, 4> ranks;
- for (Value v : linalgOp.getShapedOperands()) {
- ShapedType t = v.getType().template cast<ShapedType>();
- ranks.push_back(t.getRank());
- for (unsigned i = 0; i < t.getRank(); ++i)
- res.push_back(builder.create<DimOp>(loc, v, i));
- }
-
- auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
- if (attr) {
- // Find the correct position for inserting values for symbols.
- unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
- for (unsigned idx = 0; idx < attr.getInt(); idx++)
- symbolsPos += ranks[idx];
-
- // Append the end of the value list that corresponds to the
- // values mapping to symbols. Since inside concatinated map symbols are
- // repeated we have to repeat the sizes as well.
-
- // Reserve is mandatory to avoid a potential undefined behavior with
- // pushing back to smallvector from itself.
- res.reserve(res.size() + ranks.size() * numSymb);
- for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
- for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
- res.push_back(res[symbolsPos + idx2]);
- }
- return res;
-}
-
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
SmallVector<int64_t, 8> res;
for (Value v : linalgOp.getShapedOperands()) {
@@ -165,16 +107,6 @@ SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
return res;
}
-Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
- LinalgOp linalgOp) {
- SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
- AffineMap invertedMap =
- inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
- if (!invertedMap)
- return {};
- return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
-}
-
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
AffineMap invertedMap =
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index dcfafdc4d27a..76cd3470a12a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -131,7 +131,7 @@ func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(o
// -----
func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
+ // expected-error @+1 {{expected the shape-to-loops map to be non-null}}
linalg.generic {
indexing_maps = [
affine_map<(i, j) -> (i + j)>,
More information about the Mlir-commits
mailing list