[Mlir-commits] [mlir] 855a119 - [mlir][linalg] Allow TC ops taking an unused shaped operand.

Hanhan Wang llvmlistbot at llvm.org
Fri Feb 26 06:46:21 PST 2021


Author: Hanhan Wang
Date: 2021-02-26T06:45:56-08:00
New Revision: 855a1196049705344ec90cc1f3fd09b426416311

URL: https://github.com/llvm/llvm-project/commit/855a1196049705344ec90cc1f3fd09b426416311
DIFF: https://github.com/llvm/llvm-project/commit/855a1196049705344ec90cc1f3fd09b426416311.diff

LOG: [mlir][linalg] Allow TC ops taking an unused shaped operand.

If an operand is not used in the formula, it is considered a shape-only
operand, and the results of its indexing map are the reduction dimensions
of the comprehension.

Depends On D97383

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D97384

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg.md
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 4efb0d1c9f00..e606e67c1e7c 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -582,8 +582,9 @@ better adapt to Linalg:
     resorting to more general MLIR parsing.
 1.  Reduction dimensions are specified with angle bracket notation on the
     operation they apply to (e.g. `std_add<k>` specifies that `k` is a reduction
-    dimension). In TC, a reduction is specified with `op=` operator and the
-    reduction dimensions are inferred.
+    dimension). In TC, the reduction dimensions are inferred. If one of the
+    operands is not used in any expression, it is considered a shape-only
+    operand, and the results of its indexing_map are the reduction dimensions.
 1.  The parallel and reduction dimension are ordered by the textual program
     order. For instance, in the comprehension `O(i, j) = std_add<k, l>(...)`,
     `i` (resp. `j`) is a parallel iterator encoded by affine dimension of

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index f670ac9a3c05..8a230857dbf6 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -190,3 +190,14 @@ def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
 {
   C(m) = std_subf<k>(std_mulf(A(m, k), B(k)), C(m));
 }
+
+// Test shape-only operand.
+// IMPL-LABEL:  ArrayAttr Test9Op::indexing_maps() {
+//       IMPL:    auto map0 = AffineMap::get(2, 2, {d0, d1}, context);
+//       IMPL:    auto map1 = AffineMap::get(2, 2, {d1}, context);
+//       IMPL:    auto map2 = AffineMap::get(2, 2, {d0}, context);
+ods_def<Test9Op>:
+def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M))
+{
+  C(m) = std_addf<k>(C(m), A(m, k));
+}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 794e14780ea2..f557fa88308f 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1634,7 +1634,26 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
       tensor.indexingMap = use.indexingMap;
       state.orderedTensorArgs[use] = tensor.index;
     });
-  state.numArgs = seenDefs.size();
+  // If fewer tensors are used than are registered, the unused tensors are
+  // shape-only operands, which are used to define reduction loops. For now,
+  // only accept exactly one shape-only operand.
+  if (state.numArgs > seenDefs.size() + 1) {
+    failed = true;
+  } else if (state.numArgs == seenDefs.size() + 1) {
+    for (auto &tensorIter : registeredTensors) {
+      auto &tensor = tensorIter.getValue();
+      if (tensor.indexingMap)
+        continue;
+      if (auto *pTensorExpr =
+              dyn_cast<TensorExpr>(state.expressions[0].get())) {
+        SmallVector<AffineExpr, 4> exprs;
+        for (auto dim : pTensorExpr->reductionDimensions)
+          exprs.push_back(getAffineDimExpr(dim, parser.context));
+        tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(),
+                                            exprs, parser.context);
+      }
+    }
+  }
   if (failed)
     return failure();
 
@@ -1762,6 +1781,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
   SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
   while (parser.curToken.isNot(Token::Kind::r_brace)) {
     perComprehensionStates.push_back(ComprehensionParsingState());
+    perComprehensionStates.back().numArgs = registeredTensors.size();
     if (failed(parseOneComprehension(cppOpName, tcName,
                                      perComprehensionStates.back())))
       return failure();
@@ -2207,10 +2227,6 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
   std::string mapsStr;
   llvm::raw_string_ostream mapsStringStream(mapsStr);
 
-  SmallVector<TensorUse, 4> orderedUses(state.numArgs);
-  for (const auto &it : state.orderedTensorArgs)
-    orderedUses[it.second] = it.first;
-
   // Create a list of all symbols.
   SmallVector<std::string, 4> symbolReplacements;
   symbolReplacements.reserve(symbols.size());
@@ -2242,10 +2258,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
     symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index());
   }
 
-  // For each tensor use, construct the affine map, replace symbols by the
-  // corresponding attribute values, and simplify the affine map.
-  for (auto tensorUse : llvm::enumerate(orderedUses)) {
-    auto indexingMap = tensorUse.value().indexingMap;
+  // For each registered tensor, construct the affine map, replace symbols by
+  // the corresponding attribute values, and simplify the affine map.
+  for (auto &tensorIter : registeredTensors) {
+    auto &tensor = tensorIter.getValue();
+    auto indexingMap = tensor.indexingMap;
     const char *mapFmt =
         "\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);";
 
@@ -2255,8 +2272,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
     llvm::interleaveComma(indexingMap.getResults(), exprsStringStream);
     exprsStringStream << "}";
     exprsStringStream.flush();
-    mapsStringStream << llvm::formatv(mapFmt, tensorUse.index(),
-                                      state.dims.size(),
+    mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(),
                                       indexingMap.getNumSymbols(), exprsStr);
 
     std::string replaceSymbolList =
@@ -2269,17 +2285,17 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
     // need that.
     const char *replaceFmt =
         "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
-    mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(),
+    mapsStringStream << llvm::formatv(replaceFmt, tensor.index,
                                       replaceSymbolList, state.dims.size());
     const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});";
-    mapsStringStream << llvm::formatv(simplifyFmt, tensorUse.index());
+    mapsStringStream << llvm::formatv(simplifyFmt, tensor.index);
   }
 
   mapsStringStream.flush();
 
   SmallVector<std::string, 4> mapList;
-  mapList.reserve(orderedUses.size());
-  for (unsigned i = 0; i < orderedUses.size(); ++i)
+  mapList.reserve(state.numArgs);
+  for (auto i : llvm::seq<unsigned>(0, state.numArgs))
     mapList.push_back(llvm::formatv("map{0}", i));
 
   // 4. Apply format to 1. using 2. and 3.


        


More information about the Mlir-commits mailing list