[Mlir-commits] [mlir] 855a119 - [mlir][linalg] Allow TC ops taking an unused shaped operand.
Hanhan Wang
llvmlistbot at llvm.org
Fri Feb 26 06:46:21 PST 2021
Author: Hanhan Wang
Date: 2021-02-26T06:45:56-08:00
New Revision: 855a1196049705344ec90cc1f3fd09b426416311
URL: https://github.com/llvm/llvm-project/commit/855a1196049705344ec90cc1f3fd09b426416311
DIFF: https://github.com/llvm/llvm-project/commit/855a1196049705344ec90cc1f3fd09b426416311.diff
LOG: [mlir][linalg] Allow TC ops taking an unused shaped operand.
If an operand is not used in the formula, it is considered a shape-only
operand, and the results of its indexing map are the reduction dimensions of
the comprehension's first expression.
Depends On D97383
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D97384
Added:
Modified:
mlir/docs/Dialects/Linalg.md
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 4efb0d1c9f00..e606e67c1e7c 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -582,8 +582,9 @@ better adapt to Linalg:
resorting to more general MLIR parsing.
1. Reduction dimensions are specified with angle bracket notation on the
operation they apply to (e.g. `std_add<k>` specifies that `k` is a reduction
- dimension). In TC, a reduction is specified with `op=` operator and the
- reduction dimensions are inferred.
+ dimension). In TC, the reduction dimensions are inferred. If one of the
+ operands is not used in any expression, it is considered a shape-only
+ operand, and the results of its indexing_map are the reduction dimensions.
1. The parallel and reduction dimension are ordered by the textual program
order. For instance, in the comprehension `O(i, j) = std_add<k, l>(...)`,
`i` (resp. `j`) is a parallel iterator encoded by affine dimension of
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index f670ac9a3c05..8a230857dbf6 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -190,3 +190,14 @@ def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
{
C(m) = std_subf<k>(std_mulf(A(m, k), B(k)), C(m));
}
+
+// Test shape-only operand.
+// IMPL-LABEL: ArrayAttr Test9Op::indexing_maps() {
+// IMPL: auto map0 = AffineMap::get(2, 2, {d0, d1}, context);
+// IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context);
+// IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context);
+ods_def<Test9Op>:
+def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M))
+{
+ C(m) = std_addf<k>(C(m), A(m, k));
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 794e14780ea2..f557fa88308f 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1634,7 +1634,26 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
tensor.indexingMap = use.indexingMap;
state.orderedTensorArgs[use] = tensor.index;
});
- state.numArgs = seenDefs.size();
+ // If fewer tensor uses were seen than there are arguments, the unused
+ // arguments are shape-only operands, which are used to define reduction
+ // loops. For now, only accept exactly one shape-only operand.
+ if (state.numArgs > seenDefs.size() + 1) {
+ failed = true;
+ } else if (state.numArgs == seenDefs.size() + 1) {
+ for (auto &tensorIter : registeredTensors) {
+ auto &tensor = tensorIter.getValue();
+ if (tensor.indexingMap)
+ continue;
+ if (auto *pTensorExpr =
+ dyn_cast<TensorExpr>(state.expressions[0].get())) {
+ SmallVector<AffineExpr, 4> exprs;
+ for (auto dim : pTensorExpr->reductionDimensions)
+ exprs.push_back(getAffineDimExpr(dim, parser.context));
+ tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(),
+ exprs, parser.context);
+ }
+ }
+ }
if (failed)
return failure();
@@ -1762,6 +1781,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
while (parser.curToken.isNot(Token::Kind::r_brace)) {
perComprehensionStates.push_back(ComprehensionParsingState());
+ perComprehensionStates.back().numArgs = registeredTensors.size();
if (failed(parseOneComprehension(cppOpName, tcName,
perComprehensionStates.back())))
return failure();
@@ -2207,10 +2227,6 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
std::string mapsStr;
llvm::raw_string_ostream mapsStringStream(mapsStr);
- SmallVector<TensorUse, 4> orderedUses(state.numArgs);
- for (const auto &it : state.orderedTensorArgs)
- orderedUses[it.second] = it.first;
-
// Create a list of all symbols.
SmallVector<std::string, 4> symbolReplacements;
symbolReplacements.reserve(symbols.size());
@@ -2242,10 +2258,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index());
}
- // For each tensor use, construct the affine map, replace symbols by the
- // corresponding attribute values, and simplify the affine map.
- for (auto tensorUse : llvm::enumerate(orderedUses)) {
- auto indexingMap = tensorUse.value().indexingMap;
+ // For each registered tensor, construct the affine map, replace symbols by
+ // the corresponding attribute values, and simplify the affine map.
+ for (auto &tensorIter : registeredTensors) {
+ auto &tensor = tensorIter.getValue();
+ auto indexingMap = tensor.indexingMap;
const char *mapFmt =
"\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);";
@@ -2255,8 +2272,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
llvm::interleaveComma(indexingMap.getResults(), exprsStringStream);
exprsStringStream << "}";
exprsStringStream.flush();
- mapsStringStream << llvm::formatv(mapFmt, tensorUse.index(),
- state.dims.size(),
+ mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(),
indexingMap.getNumSymbols(), exprsStr);
std::string replaceSymbolList =
@@ -2269,17 +2285,17 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
// need that.
const char *replaceFmt =
"\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
- mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(),
+ mapsStringStream << llvm::formatv(replaceFmt, tensor.index,
replaceSymbolList, state.dims.size());
const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});";
- mapsStringStream << llvm::formatv(simplifyFmt, tensorUse.index());
+ mapsStringStream << llvm::formatv(simplifyFmt, tensor.index);
}
mapsStringStream.flush();
SmallVector<std::string, 4> mapList;
- mapList.reserve(orderedUses.size());
- for (unsigned i = 0; i < orderedUses.size(); ++i)
+ mapList.reserve(state.numArgs);
+ for (auto i : llvm::seq<unsigned>(0, state.numArgs))
mapList.push_back(llvm::formatv("map{0}", i));
// 4. Apply format to 1. using 2. and 3.
More information about the Mlir-commits
mailing list