[Mlir-commits] [mlir] 1029c82 - [mlir][Linalg] NFC - Extract a standalone LinalgInterfaces
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Feb 3 23:25:42 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-04T07:19:38Z
New Revision: 1029c82c1e199d654059fad9e3fbef6e68501863
URL: https://github.com/llvm/llvm-project/commit/1029c82c1e199d654059fad9e3fbef6e68501863
DIFF: https://github.com/llvm/llvm-project/commit/1029c82c1e199d654059fad9e3fbef6e68501863.diff
LOG: [mlir][Linalg] NFC - Extract a standalone LinalgInterfaces
This separation improves the layering and paves the way for more interfaces coming up in the future.
Differential revision: https://reviews.llvm.org/D95941
Added:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index 09db72806565..14ee4ea6968a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -45,8 +45,8 @@ add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen)
add_dependencies(MLIRLinalgStructuredOpsIncGen LinalgOdsGen)
add_dependencies(mlir-headers MLIRLinalgStructuredOpsIncGen)
-set(LLVM_TARGET_DEFINITIONS LinalgStructuredOpsInterface.td)
-mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls)
-mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRLinalgStructuredOpsInterfaceIncGen)
-add_dependencies(mlir-headers MLIRLinalgStructuredOpsInterfaceIncGen)
+set(LLVM_TARGET_DEFINITIONS LinalgInterfaces.td)
+mlir_tablegen(LinalgInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(LinalgInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRLinalgInterfacesIncGen)
+add_dependencies(mlir-headers MLIRLinalgInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
new file mode 100644
index 000000000000..e4fddd594580
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -0,0 +1,44 @@
+//===- LinalgInterfaces.h - Linalg operations interfaces ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the operation interfaces for Linalg operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
+#define MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+
+namespace mlir {
+namespace linalg {
+
+/// Returns the values obtained by applying `map` to the list of values.
+SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
+ AffineMap map, ValueRange values);
+
+namespace detail {
+
+/// Verify that `op` conforms to the invariants of StructuredOpInterface
+LogicalResult verifyStructuredOpInterface(Operation *op);
+
+} // namespace detail
+} // namespace linalg
+} // namespace mlir
+
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc"
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
similarity index 98%
rename from mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
rename to mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 7f3839a02b2f..a38b04ca16b2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1,4 +1,4 @@
-//===- LinalgStructuredInterface.td- Linalg StructuredIfce -*- tablegen -*-===//
+//===- LinalgInterfaces.td - Linalg Interfaces Declaration -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
//
-// This is the definition file for the structured interface for Linalg ops.
+// This is the definition file for the structured interfaces for Linalg ops.
//
//===----------------------------------------------------------------------===//
-#ifndef LINALG_IR_STRUCTURED_OPS_INTERFACE
-#define LINALG_IR_STRUCTURED_OPS_INTERFACE
+#ifndef LINALG_IR_LINALGINTERFACES
+#define LINALG_IR_LINALGINTERFACES
-include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/IR/OpBase.td"
// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp'
// interface.
@@ -33,10 +33,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodName=*/"getNumPayloadInductionVariables",
/*args=*/(ins),
/*methodBody=*/"",
- /*defaultImplementation=*/[{
- return isa<IndexedGenericOp>(this->getOperation()) ?
- $_op.getNumLoops() : 0;
- }]
+ /*defaultImplementation=*/""
>,
//===------------------------------------------------------------------===//
// Loop types handling.
@@ -570,7 +567,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
unsigned bbArgNumber =
- getNumPayloadInductionVariables() + opOperand->getOperandNumber();
+ $_op.getNumPayloadInductionVariables() + opOperand->getOperandNumber();
// Safeguard against the named linalg ops that are manually defined and
// that only support buffer semantics: we should not be there.
// Such ops have an empty regionBuilder and are not constructed with a
@@ -1117,4 +1114,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
}
-#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE
+#endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 4075ddd12117..f75e3010d3c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -42,10 +42,6 @@ class PoolingSumOp;
using LoopRangeBuilder =
std::function<SmallVector<Range, 4>(OpBuilder &, Location)>;
-/// Returns the values obtained by applying `map` to the list of values.
-SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
- AffineMap map, ValueRange values);
-
/// Provide a very simple inference procedure to build the loop ranges from the
/// op and its operands. This only works with permutation affine maps and
/// patterns of the form `(m, n)[s] -> (m + n - s floordiv 2)`.
@@ -122,7 +118,7 @@ namespace linalg {
class IndexedGenericOp;
} // namespace linalg
} // namespace mlir
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c88e1201f84b..8988a3a11efd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -15,7 +15,7 @@
#define LINALG_STRUCTURED_OPS
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
-include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -25,13 +25,22 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic, !listconcat(props, [
- LinalgStructuredInterface])> {}
+ LinalgStructuredInterface])> {
+ code structuredOpsBaseDecls = [{
+ // Return the number of induction variables in the basic block. This should
+ // always be 0 for index-free linalg ops. For IndexedGeneric, this must be
+ // equal to numLoops.
+ unsigned getNumPayloadInductionVariables() {
+ return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
+ }
+ }];
+}
class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
: LinalgStructuredBase_Op<mnemonic,
!listconcat(props, [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>])> {
- code libraryCallName = [{
+ code structuredOpsDecls = structuredOpsBaseDecls # [{
std::string getLibraryCallName() {
return generateLibraryCallName(getOperation());
}
@@ -110,7 +119,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
$_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
- let extraClassDeclaration = libraryCallName # [{
+ let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return getOperands().take_front(); }
ValueRange outputs() { return getOperands().take_back(); }
@@ -155,7 +164,7 @@ def FillOp : LinalgStructured_Op<"fill", []> {
let arguments = (ins AnyShaped:$output,
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
let results = (outs Optional<AnyRankedTensor>:$result);
- let extraClassDeclaration = libraryCallName # [{
+ let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return {}; }
ValueRange outputs() { return getOperands().take_front(); }
@@ -232,7 +241,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
for both low and high in each of the dimensions, if not specified.
}];
- code commonUtils = libraryCallName # [{
+ code commonUtils = structuredOpsDecls # [{
int64_t getStride(unsigned i) {
assert(i < getNumWindowLoops());
if (!strides().hasValue()) return 1;
@@ -497,7 +506,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
OptionalAttr<ArrayAttr>:$sparse);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
- let extraClassDeclaration = [{
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
return SmallVector<StringRef, 8>{
getDocAttrName(),
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index 3ed79a554b31..8522919bacb3 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLinalg
+ LinalgInterfaces.cpp
LinalgOps.cpp
LinalgTypes.cpp
@@ -6,9 +7,9 @@ add_mlir_dialect_library(MLIRLinalg
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
DEPENDS
+ MLIRLinalgInterfacesIncGen
MLIRLinalgOpsIncGen
MLIRLinalgStructuredOpsIncGen
- MLIRLinalgStructuredOpsInterfaceIncGen
LINK_LIBS PUBLIC
MLIRAffine
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
new file mode 100644
index 000000000000..f9b17dd38fe0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -0,0 +1,294 @@
+//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Include the definitions of the Linalg operation interfaces.
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
+
+/// Fully compose map with operands and canonicalize the result.
+/// Return the `createOrFold`'ed AffineApply op.
+static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange operandsRef) {
+ SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
+ fullyComposeAffineMapAndOperands(&map, &operands);
+ canonicalizeMapAndOperands(&map, &operands);
+ return b.createOrFold<AffineApplyOp>(loc, map, operands);
+}
+
+SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange values) {
+ SmallVector<Value, 4> res;
+ res.reserve(map.getNumResults());
+ unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
+ // For each `expr` in `map`, applies the `expr` to the values extracted from
+ // ranges. If the resulting application can be folded into a Value, the
+ // folding occurs eagerly.
+ for (auto expr : map.getResults()) {
+ AffineMap map = AffineMap::get(numDims, numSym, expr);
+ res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
+ }
+ return res;
+}
+
+SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
+ Location loc) {
+ SmallVector<Value, 4> res;
+ for (Value v : getShapedOperands()) {
+ ShapedType t = v.getType().template cast<ShapedType>();
+ for (unsigned i = 0, e = t.getRank(); i < e; ++i)
+ res.push_back(b.create<DimOp>(loc, v, i));
+ }
+ return res;
+}
+
+SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
+ AffineMap map = getLoopsToShapesMap();
+ unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+ auto viewSizes = createFlatListOfOperandDims(b, loc);
+ SmallVector<Range, 4> res(numDims);
+ Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
+ Value oneVal = b.create<ConstantIndexOp>(loc, 1);
+ for (unsigned idx = 0; idx < numRes; ++idx) {
+ auto result = map.getResult(idx);
+ if (auto d = result.dyn_cast<AffineDimExpr>()) {
+ if (res[d.getPosition()].offset)
+ continue;
+ res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
+ }
+ }
+ return res;
+}
+
+/// Visitor to check if any of the given set of positions from AffineDimExprs
+/// are used within an AffineExpr.
+struct HasAffineDimExprVisitor
+ : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
+ HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
+ : positions(positions) {}
+
+ bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
+ return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
+ }
+
+ bool visitDimExpr(AffineDimExpr dimExpr) {
+ return positions.count(dimExpr.getPosition());
+ }
+
+ bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
+
+ bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
+
+private:
+ llvm::SmallSet<unsigned, 4> positions;
+};
+
+Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
+ Location loc,
+ unsigned resultIdx,
+ unsigned dim) {
+ // An example that helps understand the logic below.
+ // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
+ // We want to express the shape of dim 0 of O in terms of shape of the inputs.
+ // This is achieved as follows.
+ // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
+ // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+ // shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
+ // resultFromInputDim = subMapOfResultDim.compose(shapesToLoopsMap)
+ // = (d0, d1, d2, d3, d4, d5) -> (d0 + d3)
+ AffineMap loopsToShapesMap = getLoopsToShapesMap();
+
+ // Find the position in the above map that represents the shape of the
+ // result:dim being inferred.
+ Optional<unsigned> resultDimSubMapPos =
+ getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
+ if (!resultDimSubMapPos)
+ return {};
+
+ /// From loopsToShapesMap extract the submap that represents the shape of the
+ /// (resultIdx, dim) needed
+ AffineMap loopToResultDimShapeMap =
+ loopsToShapesMap.getSubMap(*resultDimSubMapPos);
+ AffineMap operandShapesToResultDimMap =
+ loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+
+ // Check that the result dim map does not contain the positions corresponding
+ // to the outputs.
+ llvm::SmallSet<unsigned, 4> outputDims;
+ unsigned outputDimPosStart =
+ getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
+ unsigned outputDimPosEnd =
+ getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
+ getOutputOpOperands()
+ .back()
+ .get()
+ .getType()
+ .cast<ShapedType>()
+ .getRank() -
+ 1)
+ .getValue();
+ llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
+ [&outputDims](unsigned dim) { outputDims.insert(dim); });
+ HasAffineDimExprVisitor checkDimExpr(outputDims);
+ if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
+ return llvm::None;
+ return applyMapToValues(b, loc, operandShapesToResultDimMap,
+ createFlatListOfOperandDims(b, loc))[0];
+}
+
+LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ // Expect at least one shaped operand.
+ // This means an op that constructs a tensor out of indices cannot be a
+ // LinalgOp at the moment. For now this will have to be a special op until we
+ // have output shape operands that are not tensors.
+ auto nShapedOperands = linalgOp.getNumShapedOperands();
+ if (nShapedOperands == 0)
+ return linalgOp.emitOpError("expected at least 1 Shaped operand");
+ if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands)))
+ return failure();
+ // Should have at least one output tensor per result tensor.
+ // Can also have output buffers that do not correspond to results.
+ if (op->getNumResults() > linalgOp.getNumOutputTensors())
+ return op->emitError("unexpected #results > #outputs");
+
+ // All shaped operands must be indexed.
+ if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands())
+ return linalgOp.emitOpError("expected the number of indexing_map (")
+ << linalgOp.indexing_maps().size()
+ << ") to be equal to the number of shaped operands ("
+ << linalgOp.getNumShapedOperands() << ")";
+
+ SmallVector<AffineMap, 4> indexingMaps;
+ indexingMaps.reserve(linalgOp.indexing_maps().size());
+ for (auto en : llvm::enumerate(linalgOp.indexing_maps())) {
+ auto idx = en.index();
+ auto m = en.value().template cast<AffineMapAttr>().getValue();
+ indexingMaps.push_back(m); // Save reference to map for further checks.
+ auto shapedValue = linalgOp.getShapedType(idx);
+
+ // Symbols disallowed.
+ if (m.getNumSymbols() != 0)
+ return linalgOp.emitOpError("unexpected symbols in indexing_map #")
+ << idx;
+
+ // Domain must be consistent.
+ auto nLoops = linalgOp.getNumLoops();
+ if (m.getNumDims() != nLoops)
+ return linalgOp.emitOpError("expected indexing_map #")
+ << idx << " to have " << nLoops
+ << " dim(s) to match the number of loops";
+
+ if (m.getNumResults() != shapedValue.getRank())
+ return linalgOp.emitOpError("expected shaped value rank (")
+ << shapedValue.getRank()
+ << ") to match the result rank of indexing_map #" << idx << " ("
+ << m.getNumResults() << ")";
+ }
+
+ SmallVector<AffineExpr, 4> redDims;
+ linalgOp.getReductionDims(redDims);
+
+ // Simplifying assumption: either full tensor or full buffer mode.
+ // This allows simpler verification of output operands vs result types
+ // without premature tracking of which operand is what in mixed-mode.
+ // TODO: relax when mixed-mode needs to pass verification.
+ if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0)
+ return op->emitError("expected output operands to all have tensor type or "
+ "all have buffer type");
+
+ for (auto it :
+ llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) {
+ if (!std::get<0>(it).get().getType().isa<RankedTensorType>())
+ continue;
+ if (std::get<0>(it).get().getType() != std::get<1>(it))
+ return op->emitError("expected type of operand #")
+ << std::get<0>(it).getOperandNumber() << " ("
+ << std::get<0>(it).get().getType() << ")"
+ << " to match type of corresponding result (" << std::get<1>(it)
+ << ")";
+ }
+
+ // Output tensor indexing map may not depend on reduction indices.
+ for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) {
+ AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber());
+ for (auto expr : outputMap.getResults()) {
+ for (auto dim : redDims) {
+ unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+ if (expr.isFunctionOfDim(pos)) {
+ std::string exprStr;
+ {
+ llvm::raw_string_ostream os(exprStr);
+ os << expr;
+ }
+ return op->emitError(
+ "unexpected output tensor expression in indexing map #")
+ << (opOperand.getOperandNumber() - linalgOp.getNumInputs())
+ << " a.k.a '" << exprStr
+ << "' is function of reduction iterator 'd" << pos << "'";
+ }
+ }
+ }
+ }
+
+ // Named ops that are defined manually have neither a region builder nor a
+ // region at this time. Assume the region is well-formed by specification.
+ // TODO: use linalg-ods-gen for all ops when we have enough expressive power.
+ if (linalgOp->getNumRegions() == 0) {
+ assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region");
+ return success();
+ }
+
+ auto ®ion = linalgOp->getRegion(0);
+ if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region))
+ return op->emitOpError("expected 1 region with 1 block");
+
+ if (!linalgOp.getShapesToLoopsMap())
+ return op->emitOpError("expected the shape-to-loops map to be non-null");
+
+ // Simplifying assumption: bbargs match 1-1 with shape operands elemental
+ // types.
+ // TODO: once ranked shape types are plugged in, we may want to drop the
+ // corresponding bbargs, that can never be read from. This will be subject to
+ // consistency discussions (i.e. what to do with output tensors whose bbarg is
+ // not used).
+ Block &block = linalgOp->getRegion(0).front();
+ unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables();
+
+ if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments())
+ return op->emitError("expected as many non-induction variable region "
+ "arguments as the number of shaped operands");
+
+ // Note: the number and type of yield values are checked in the YieldOp.
+ for (unsigned i = 0; i < numBBIvs; ++i)
+ if (!block.getArgument(i).getType().isIndex())
+ return op->emitOpError("expected index block argument #") << i;
+
+ unsigned idx = 0;
+ for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(),
+ block.getArguments().drop_front(numBBIvs))) {
+ if (std::get<0>(it).getElementType() != std::get<1>(it).getType())
+ return op->emitError("expected type of bb argument #")
+ << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")"
+ << " to match element type of corresponding shaped operand ("
+ << std::get<0>(it).getElementType() << ")";
+ ++idx;
+ }
+
+ return success();
+}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7d2685f8166a..7a720d3e68bc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,138 +32,6 @@
using namespace mlir;
using namespace mlir::linalg;
-/// Fully compose map with operands and canonicalize the result.
-/// Return the `createOrFold`'ed AffineApply op.
-static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange operandsRef) {
- SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
- fullyComposeAffineMapAndOperands(&map, &operands);
- canonicalizeMapAndOperands(&map, &operands);
- return b.createOrFold<AffineApplyOp>(loc, map, operands);
-}
-
-SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange values) {
- SmallVector<Value, 4> res;
- res.reserve(map.getNumResults());
- unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
- // For each `expr` in `map`, applies the `expr` to the values extracted from
- // ranges. If the resulting application can be folded into a Value, the
- // folding occurs eagerly.
- for (auto expr : map.getResults()) {
- AffineMap map = AffineMap::get(numDims, numSym, expr);
- res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
- }
- return res;
-}
-
-SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
- Location loc) {
- SmallVector<Value, 4> res;
- for (Value v : getShapedOperands()) {
- ShapedType t = v.getType().template cast<ShapedType>();
- for (unsigned i = 0, e = t.getRank(); i < e; ++i)
- res.push_back(b.create<DimOp>(loc, v, i));
- }
- return res;
-}
-
-SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
- AffineMap map = getLoopsToShapesMap();
- unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
- auto viewSizes = createFlatListOfOperandDims(b, loc);
- SmallVector<Range, 4> res(numDims);
- Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
- Value oneVal = b.create<ConstantIndexOp>(loc, 1);
- for (unsigned idx = 0; idx < numRes; ++idx) {
- auto result = map.getResult(idx);
- if (auto d = result.dyn_cast<AffineDimExpr>()) {
- if (res[d.getPosition()].offset)
- continue;
- res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
- }
- }
- return res;
-}
-
-/// Visitor to check if any of the given set of positions from AffineDimExprs
-/// are used within an AffineExpr.
-struct HasAffineDimExprVisitor
- : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
- HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
- : positions(positions) {}
-
- bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
- return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
- }
-
- bool visitDimExpr(AffineDimExpr dimExpr) {
- return positions.count(dimExpr.getPosition());
- }
-
- bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
-
- bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
-
-private:
- llvm::SmallSet<unsigned, 4> positions;
-};
-
-Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
- Location loc,
- unsigned resultIdx,
- unsigned dim) {
- // An example that helps understand the logic below.
- // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
- // We want to express the shape of dim 0 of O in terms of shape of the inputs.
- // This is achieved as follows.
- // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
- // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
- // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
- // resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
- // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
- AffineMap loopsToShapesMap = getLoopsToShapesMap();
-
- // Find the position in the above map that represents the shape of the
- // result:dim being inferred.
- Optional<unsigned> resultDimSubMapPos =
- getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
- if (!resultDimSubMapPos)
- return {};
-
- /// From loopsToShapesMap extract the submap that represents the shape of the
- /// (resultIdx, dim) needed
- AffineMap loopToResultDimShapeMap =
- loopsToShapesMap.getSubMap(*resultDimSubMapPos);
- AffineMap operandShapesToResultDimMap =
- loopToResultDimShapeMap.compose(getShapesToLoopsMap());
-
- // Check that the result dim map does not contain the positions corresponding
- // to the outputs.
- llvm::SmallSet<unsigned, 4> outputDims;
- unsigned outputDimPosStart =
- getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
- unsigned outputDimPosEnd =
- getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
- getOutputOpOperands()
- .back()
- .get()
- .getType()
- .cast<ShapedType>()
- .getRank() -
- 1)
- .getValue();
- llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
- [&outputDims](unsigned dim) { outputDims.insert(dim); });
- HasAffineDimExprVisitor checkDimExpr(outputDims);
- if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
- return llvm::None;
- return applyMapToValues(b, loc, operandShapesToResultDimMap,
- createFlatListOfOperandDims(b, loc))[0];
-}
-
/// Forward declarations.
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
@@ -215,11 +83,6 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
-///////////////////// Operations defined with Tablegen /////////////////////////
-// For such operations that do not correspond to library calls (i.e. defined in
-// LinalgOps.td), we define an overloaded `print` function and a
-// parse`className` function.
-
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
@@ -471,148 +334,6 @@ void IndexedGenericOp::getEffects(
getInputBuffers(), getOutputBuffers());
}
-LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
- LinalgOp linalgOp = cast<LinalgOp>(op);
- // Expect at least one shaped operand.
- // This means an op that constructs a tensor out of indices cannot be a
- // LinalgOp at the moment. For now this will have to be a special op until we
- // have output shape operands that are not tensors.
- auto nShapedOperands = linalgOp.getNumShapedOperands();
- if (nShapedOperands == 0)
- return linalgOp.emitOpError("expected at least 1 Shaped operand");
- if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands)))
- return failure();
- // Should have at least one output tensor per result tensor.
- // Can also have outbut buffers that do not correspond to results.
- if (op->getNumResults() > linalgOp.getNumOutputTensors())
- return op->emitError("unexpected #results > #outputs");
-
- // All shaped operands must be indexed.
- if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands())
- return linalgOp.emitOpError("expected the number of indexing_map (")
- << linalgOp.indexing_maps().size()
- << ") to be equal to the number of shaped operands ("
- << linalgOp.getNumShapedOperands() << ")";
-
- SmallVector<AffineMap, 4> indexingMaps;
- indexingMaps.reserve(linalgOp.indexing_maps().size());
- for (auto en : llvm::enumerate(linalgOp.indexing_maps())) {
- auto idx = en.index();
- auto m = en.value().template cast<AffineMapAttr>().getValue();
- indexingMaps.push_back(m); // Save reference to map for further checks.
- auto shapedValue = linalgOp.getShapedType(idx);
-
- // Symbols disallowed.
- if (m.getNumSymbols() != 0)
- return linalgOp.emitOpError("unexpected symbols in indexing_map #")
- << idx;
-
- // Domain must be consistent.
- auto nLoops = linalgOp.getNumLoops();
- if (m.getNumDims() != nLoops)
- return linalgOp.emitOpError("expected indexing_map #")
- << idx << " to have " << nLoops
- << " dim(s) to match the number of loops";
-
- if (m.getNumResults() != shapedValue.getRank())
- return linalgOp.emitOpError("expected shaped value rank (")
- << shapedValue.getRank()
- << ") to match the result rank of indexing_map #" << idx << " ("
- << m.getNumResults() << ")";
- }
-
- SmallVector<AffineExpr, 4> redDims;
- linalgOp.getReductionDims(redDims);
-
- // Simplifying assumption: either full tensor or full buffer mode.
- // This allows simpler verification of output operands vs result types
- // without premature tracking of which operand is what in mixed-mode.
- // TODO: relax when mixed-mode needs to pass verification.
- if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0)
- return op->emitError("expected output operands to all have tensor type or "
- "all have buffer type");
-
- for (auto it :
- llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) {
- if (!std::get<0>(it).get().getType().isa<RankedTensorType>())
- continue;
- if (std::get<0>(it).get().getType() != std::get<1>(it))
- return op->emitError("expected type of operand #")
- << std::get<0>(it).getOperandNumber() << " ("
- << std::get<0>(it).get().getType() << ")"
- << " to match type of corresponding result (" << std::get<1>(it)
- << ")";
- }
-
- // Output tensor indexing map may not depend on reduction indices.
- for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) {
- AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber());
- for (auto expr : outputMap.getResults()) {
- for (auto dim : redDims) {
- unsigned pos = dim.cast<AffineDimExpr>().getPosition();
- if (expr.isFunctionOfDim(pos)) {
- std::string exprStr;
- {
- llvm::raw_string_ostream os(exprStr);
- os << expr;
- }
- return op->emitError(
- "unexpected output tensor expression in indexing map #")
- << (opOperand.getOperandNumber() - linalgOp.getNumInputs())
- << " a.k.a '" << exprStr
- << "' is function of reduction iterator 'd" << pos << "'";
- }
- }
- }
- }
-
- // Named ops that are defined manually have a region builder but no region at
- // this time. Assume the region is well-formed by specification.
- // TODO: use linalg-ods-gen for all ops when we have enough expressive power.
- if (linalgOp->getNumRegions() == 0) {
- assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region");
- return success();
- }
-
- auto &region = linalgOp->getRegion(0);
- if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region))
- return op->emitOpError("expected 1 region with 1 block");
-
- if (!linalgOp.getShapesToLoopsMap())
- return op->emitOpError("expected the shape-to-loops map to be non-null");
-
- // Simplifying assumption: bbargs match 1-1 with shape operands elemental
- // types.
- // TODO: once ranked shape types are plugged in, we may want to drop the
- // corresponding bbargs, that can never be read from. This will be subject to
- // consistency discussions (i.e. what to do with output tensors whose bbarg is
- // not used).
- Block &block = linalgOp->getRegion(0).front();
- unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables();
-
- if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments())
- return op->emitError("expected as many non-induction variable region "
- "arguments as the number of shaped operands");
-
- // Note: the number and type of yield values are checked in the YieldOp.
- for (unsigned i = 0; i < numBBIvs; ++i)
- if (!block.getArgument(i).getType().isIndex())
- return op->emitOpError("expected index block argument #") << i;
-
- unsigned idx = 0;
- for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(),
- block.getArguments().drop_front(numBBIvs))) {
- if (std::get<0>(it).getElementType() != std::get<1>(it).getType())
- return op->emitError("expected type of bb argument #")
- << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")"
- << " to match element type of corresponding shaped operand ("
- << std::get<0>(it).getElementType() << ")";
- ++idx;
- }
-
- return success();
-}
-
namespace {
template <typename GenericOpType>
@@ -1901,8 +1622,6 @@ struct EraseDeadLinalgOp;
struct FoldTensorCastOp;
} // namespace
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
-
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 47841c840fe5..9bf763079470 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1863,12 +1863,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
let hasFolder = 1;
let hasCanonicalizer = 1;
- let extraClassDeclaration = [{{
+ let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(Block &block);
- static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
+ static std::function<void(Block &)> getRegionBuilder() {{
+ return regionBuilder;
+ }
// Generic methods.
static unsigned getNumRegionArgs() {{ return {4}; }
More information about the Mlir-commits
mailing list