[Mlir-commits] [mlir] 3bc7555 - [mlir][linalg] Use attributes in named ops' indexing maps

Lei Zhang llvmlistbot at llvm.org
Wed Jan 13 07:08:31 PST 2021


Author: Lei Zhang
Date: 2021-01-13T10:04:49-05:00
New Revision: 3bc7555ffac0a803e44c4b1462e0c4c5eee865ea

URL: https://github.com/llvm/llvm-project/commit/3bc7555ffac0a803e44c4b1462e0c4c5eee865ea
DIFF: https://github.com/llvm/llvm-project/commit/3bc7555ffac0a803e44c4b1462e0c4c5eee865ea.diff

LOG: [mlir][linalg] Use attributes in named ops' indexing maps

This commit adds support for parsing attribute uses in indexing
maps. These attribute uses are represented as affine symbols in
the resultant indexing maps because we can only know their
concrete values (which come from op attributes and are
constants) for specific op instances. The `indexing_maps()`
implementations are synthesized to read these attributes, create affine
constants to replace the placeholder affine symbols, and simplify the maps.

Depends on D94240

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D94335

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg.md
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 922455dddbda..1f8ef3c4021b 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -590,6 +590,12 @@ better adapt to Linalg:
     `i` (resp. `j`) is a parallel iterator encoded by affine dimension of
     position `0` (resp. `1`); `k` (resp. `l`) is a reduction iterator encoded by
     an affine dimension of position `2` (resp. `3`).
+1.  A list of attributes can be defined for the op with the format of `attr(
+    strides: 2xi32)` and referenced in comprehension like `strides[0]`. These
+    attribute uses will be parsed as affine symbols to generate op definition
+    and implementation. For a concrete op instance, the runtime constant values
+    from the attributes will be used to replace the affine symbols and simplify
+    the indexing maps.
 
 These decisions and syntax are subject to evolution and change. In particular,
 op-specific attributes, dynamic ranks, some form of templating, shape

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 1ef128760637..1ce2d2ac9418 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -10,9 +10,18 @@
 //       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
 //       IMPL:  ArrayAttr Test1Op::indexing_maps() {
-//       IMPL:  AffineMap::get(2, 0, {d0, d1}, context),
-//  IMPL-NEXT:  AffineMap::get(2, 0, {d1}, context),
-//  IMPL-NEXT:  AffineMap::get(2, 0, {d0}, context) });
+//       IMPL: auto s0 = getAffineSymbolExpr(0, context); (void)s0;
+//  IMPL-NEXT: auto s1 = getAffineSymbolExpr(1, context); (void)s1;
+//  IMPL-NEXT: auto map0 = AffineMap::get(2, 2, {d0, d1}, context);
+//  IMPL-NEXT: map0 = map0.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0);
+//  IMPL-NEXT: map0 = simplifyAffineMap(map0);
+//  IMPL-NEXT: auto map1 = AffineMap::get(2, 2, {d1}, context);
+//  IMPL-NEXT: map1 = map1.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0);
+//  IMPL-NEXT: map1 = simplifyAffineMap(map1);
+//  IMPL-NEXT: auto map2 = AffineMap::get(2, 2, {d0}, context);
+//  IMPL-NEXT: map2 = map2.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0);
+//  IMPL-NEXT: map2 = simplifyAffineMap(map2);
+//  IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
 //
 //       IMPL:  void Test1Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
@@ -34,9 +43,9 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 //       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
 //       IMPL:  ArrayAttr Test2Op::indexing_maps() {
-//       IMPL:  AffineMap::get(3, 0, {d0, d2}, context),
-//  IMPL-NEXT:  AffineMap::get(3, 0, {d2, d1}, context),
-//  IMPL-NEXT:  AffineMap::get(3, 0, {d0, d1}, context) });
+//       IMPL:  AffineMap::get(3, 3, {d0, d2}, context)
+//       IMPL:  AffineMap::get(3, 3, {d2, d1}, context)
+//       IMPL:  AffineMap::get(3, 3, {d0, d1}, context)
 //
 //       IMPL:  Test2Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
@@ -58,9 +67,9 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 //       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
 //       IMPL:  ArrayAttr Test3Op::indexing_maps() {
-//       IMPL:  AffineMap::get(4, 0, {d0, d1, d3}, context),
-//  IMPL-NEXT:  AffineMap::get(4, 0, {d3, d2}, context),
-//  IMPL-NEXT:  AffineMap::get(4, 0, {d0, d1, d2}, context) });
+//       IMPL:  AffineMap::get(4, 4, {d0, d1, d3}, context)
+//       IMPL:  AffineMap::get(4, 4, {d3, d2}, context)
+//       IMPL:  AffineMap::get(4, 4, {d0, d1, d2}, context)
 //
 //       IMPL:  Test3Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
@@ -94,3 +103,25 @@ attr(
 ) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
 }
+
+// Test attribute usage in affine expressions
+// IMPL-LABEL: ArrayAttr Test5Op::indexing_maps() {
+// IMPL: auto cst0 = getAffineConstantExpr(strides().getValue<int>({ 0 }), context);
+// IMPL: auto cst1 = getAffineConstantExpr(strides().getValue<int>({ 1 }), context);
+// IMPL: auto map0 = AffineMap::get(7, 9, {d0, d1 * s7 + d4, d2 * s8 + d5, d6}, context);
+// IMPL: map0 = map0.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0);
+// IMPL: map0 = simplifyAffineMap(map0);
+// IMPL: auto map1 = AffineMap::get(7, 9, {d3, d4, d5, d6}, context);
+// IMPL: map1 = map1.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0);
+// IMPL: map1 = simplifyAffineMap(map1);
+// IMPL: auto map2 = AffineMap::get(7, 7, {d0, d1, d2, d3}, context);
+// IMPL: map2 = map2.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0);
+// IMPL: map2 = simplifyAffineMap(map2);
+// IMPL: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
+//
+ods_def<Test5Op>:
+def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F))
+     attr(strides: 2xi32) {
+  O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
+    I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c)));
+}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 138c5a4e904e..cb7bfd2c9c4d 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -19,8 +19,12 @@
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/Twine.h"
@@ -366,6 +370,14 @@ class Parser {
   // Lexer Utilities
   //===--------------------------------------------------------------------===//
 
+  LogicalResult parseInteger(uint64_t &value) {
+    if (!curToken.is(Token::Kind::integer))
+      return emitError(curToken.getLoc(), "expected integer");
+    value = curToken.getUInt64IntegerValue().getValue();
+    consumeToken();
+    return success();
+  }
+
   /// Advance the current lexer onto the next token.
   void consumeToken() {
     assert(curToken.getKind() != Token::Kind::eof &&
@@ -447,6 +459,30 @@ class Parser {
 };
 } // namespace
 
+/// Encodes an attribute use of the form:
+///
+///   index-list ::= integer-literal (`,` integer-literal)*
+///   attr-use ::= bare-id `[` index-list `]`
+struct AttrUse {
+  // Referenced attribute
+  StringRef attrName;
+  // Indices into the attribute
+  SmallVector<uint64_t, 4> indices;
+  /// Affine symbol for this usage.
+  /// This is represented as an affine symbol because at the time of parsing the
+  /// spec and generating the op's ODS/C++, we don't know the concrete constant
+  /// value. But they should be replaced with constants read from the attribute
+  /// and thus folded away for concrete op instances.
+  AffineExpr symbol;
+
+  std::string getKey() {
+    SmallVector<std::string, 4> indexStrs;
+    for (uint64_t index : indices)
+      indexStrs.push_back(std::to_string(index));
+    return llvm::formatv("{0}[{1}]", attrName, llvm::join(indexStrs, ","));
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Affine parsing.
 //===----------------------------------------------------------------------===//
@@ -479,10 +515,21 @@ using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
 /// This is a specialized parser for affine expressions.
 class AffineParser {
 public:
-  explicit AffineParser(Parser &p,
-                        std::function<AffineExpr(StringRef)> bareIdParsingHook,
-                        AffineDimList &dimList, AffineSymbolList &symbolList)
-      : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList),
+  /// Creates an affine parser that parses tokens from `p`.
+  ///
+  /// The affine parser introduces new dimensions and symbols eagerly as new
+  /// `id` are discovered. To additionally support attribute use `id`s, for a
+  /// parsed `id`, the resolution mechanism proceeds as follows:
+  /// 1. Try to parse `id` as an attribute use (using the `attrUseParsingHook`).
+  /// 2. If unsuccessful, try to match `id` to a known dim or symbol.
+  /// 3. If still unsuccessful, eagerly create a new dim or symbol and add it to
+  ///    the known dims or symbols (using the `bareIdParsingHook`).
+  explicit AffineParser(
+      Parser &p, std::function<AffineExpr(StringRef)> bareIdParsingHook,
+      std::function<llvm::Optional<AffineExpr>()> attrUseParsingHook,
+      AffineDimList &dimList, AffineSymbolList &symbolList)
+      : parser(p), bareIdFallback(bareIdParsingHook),
+        attrUseCallback(attrUseParsingHook), dims(dimList),
         symbols(symbolList) {}
 
   /// Parse a comma-separated list of affine exprs.
@@ -502,6 +549,7 @@ class AffineParser {
   AffineExpr parseParentheticalExpr();
   AffineExpr parseNegateExpression(AffineExpr lhs);
   AffineExpr parseIntegerExpr();
+  AffineExpr parseAttrUseOrBareIdExpr();
   AffineExpr parseBareIdExpr();
 
   AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
@@ -515,6 +563,7 @@ class AffineParser {
 
   Parser &parser;
   std::function<AffineExpr(StringRef)> bareIdFallback;
+  std::function<llvm::Optional<AffineExpr>()> attrUseCallback;
   AffineDimList &dims;
   AffineSymbolList &symbols;
 };
@@ -688,6 +737,12 @@ AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
   return (-1) * operand;
 }
 
+AffineExpr AffineParser::parseAttrUseOrBareIdExpr() {
+  if (llvm::Optional<AffineExpr> attrUse = attrUseCallback())
+    return attrUse.getValue();
+  return parseBareIdExpr();
+}
+
 /// Parse a bare id that may appear in an affine expression.
 ///
 ///   affine-expr ::= bare-id
@@ -739,7 +794,7 @@ AffineExpr AffineParser::parseIntegerExpr() {
 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
   switch (parser.curToken.getKind()) {
   case Token::Kind::id:
-    return parseBareIdExpr();
+    return parseAttrUseOrBareIdExpr();
   case Token::Kind::integer:
     return parseIntegerExpr();
   case Token::Kind::l_paren:
@@ -994,8 +1049,12 @@ class TCParser {
   LogicalResult parseTensorUse(TensorUse &result,
                                ComprehensionParsingState &state);
 
+  /// Parses an attribute definition.
   LogicalResult parseAttrDef();
 
+  /// Parses an optional attribute use.
+  LogicalResult parseAttrUse(AttrUse &result);
+
   /// Parses a tensor expression.
   LogicalResult parseExpression(TensorUse currentDefinition,
                                 std::unique_ptr<Expression> &result,
@@ -1053,6 +1112,10 @@ class TCParser {
     SmallVector<uint64_t, 4> vectorDims;
     bool isArray;
     bool isOptional;
+
+    // Returns the function to get values at the given indices from this
+    // attribute.
+    std::string getValueFn(ArrayRef<uint64_t> indices) const;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1061,6 +1124,9 @@ class TCParser {
   /// Symbols are per TC def.
   AffineSymbolList symbols;
 
+  /// Attribute usages in all affine expressions.
+  SmallVector<AttrUse, 8> attrUses;
+
   /// Tensors are per TC def.
   llvm::StringMap<RegisteredTensor> registeredTensors;
   unsigned nextRegisteredTensorIndex;
@@ -1147,20 +1213,45 @@ SmallVector<AffineExpr, 4>
 TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
                            AffineDimList &dims, Token::Kind lDelim,
                            Token::Kind rDelim) {
-  AffineParser affineParser(
-      parser,
-      [&](StringRef sRef) {
-        AffineExpr expr;
-        if (discoveryMode == EagerDiscoveryMode::Symbols) {
-          expr = getAffineSymbolExpr(symbols.size(), parser.context);
-          symbols.emplace_back(sRef, expr);
-        } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
-          expr = getAffineDimExpr(dims.size(), parser.context);
-          dims.emplace_back(sRef, expr);
-        }
-        return expr;
-      },
-      dims, symbols);
+  auto createAffineBareId = [&](StringRef sRef) {
+    AffineExpr expr;
+    if (discoveryMode == EagerDiscoveryMode::Symbols) {
+      expr = getAffineSymbolExpr(symbols.size(), parser.context);
+      symbols.emplace_back(sRef, expr);
+    } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
+      expr = getAffineDimExpr(dims.size(), parser.context);
+      dims.emplace_back(sRef, expr);
+    }
+    return expr;
+  };
+
+  auto tryToParseAttrUse = [&]() -> llvm::Optional<AffineExpr> {
+    if (!parser.curToken.is(Token::Kind::id))
+      return llvm::None;
+
+    StringRef attrName = parser.curToken.getSpelling();
+    auto it = registeredAttrs.find(attrName.str());
+    if (it == registeredAttrs.end())
+      return llvm::None;
+
+    AttrUse result;
+    if (failed(parseAttrUse(result)))
+      return llvm::None;
+
+    // We create a new symbol for each attribute usage without reuse. This is
+    // fine given these symbols will be replaced with constants and folded away
+    // for concrete op instances.
+    result.symbol = getAffineSymbolExpr(symbols.size(), parser.context);
+    // Merely for taking the index. We don't reuse anyway.
+    symbols.emplace_back("<attr-use>", result.symbol);
+
+    attrUses.push_back(result);
+
+    return result.symbol;
+  };
+
+  AffineParser affineParser(parser, createAffineBareId, tryToParseAttrUse, dims,
+                            symbols);
   return affineParser.parseAffineExprs(lDelim, rDelim);
 }
 
@@ -1241,8 +1332,9 @@ LogicalResult TCParser::parseAttrDef() {
   // Parse potential dimension list
   SmallVector<uint64_t, 4> vectorDims;
   while (parser.curToken.is(Token::Kind::integer)) {
-    vectorDims.push_back(parser.curToken.getUInt64IntegerValue().getValue());
-    parser.consumeToken();
+    uint64_t value;
+    parser.parseInteger(value);
+    vectorDims.push_back(value);
 
     StringRef spelling = parser.curToken.getSpelling();
     if (spelling[0] != 'x')
@@ -1286,6 +1378,44 @@ LogicalResult TCParser::parseAttrDef() {
   return success();
 }
 
+LogicalResult TCParser::parseAttrUse(AttrUse &result) {
+  result.attrName = parser.curToken.getSpelling();
+  if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
+    return failure();
+
+  auto it = registeredAttrs.find(result.attrName.str());
+  assert(it != registeredAttrs.end());
+  const RegisteredAttr &attr = it->second;
+
+  if (!attr.vectorDims.empty() || attr.isArray) {
+    // This is a vector/array attribute. Parse indices for it.
+    auto indexLoc = parser.curToken.getLoc();
+
+    if (failed(parser.parseToken(Token::Kind::l_square, "expected '['")))
+      return failure();
+
+    auto parseIndex = [&]() {
+      uint64_t value;
+      if (failed(parser.parseInteger(value)))
+        return failure();
+      result.indices.push_back(value);
+      return success();
+    };
+    if (failed(parser.parseCommaSeparatedListUntil(
+            Token::Kind::r_square, parseIndex, /*allowEmptyList=*/false)))
+      return failure();
+
+    size_t rank = attr.isArray ? 1 : attr.vectorDims.size();
+    if (result.indices.size() != rank)
+      return parser.emitError(indexLoc,
+                              "number of indices mismatch: expected " +
+                                  std::to_string(rank) + ", but found " +
+                                  std::to_string(result.indices.size()));
+  }
+
+  return success();
+}
+
 /// Parses a tensor expression of the form:
 ///
 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
@@ -1776,7 +1906,8 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
     MLIRContext *context = getContext();
     AffineExpr {1};
     bindDims(context, {1});
-    return Builder(context).getAffineMapArrayAttr({ {2} });
+    {2}
+    return Builder(context).getAffineMapArrayAttr({ {3} });
   })FMT";
 
   // 2. Print a comma-separated list of identifiers for the AffineExpr in
@@ -1790,36 +1921,89 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
       [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
   ss.flush();
 
-  // 3. Print a comma-separated list of AffineMap constructors that use the
-  // identifiers from 1. The AffineExpr use the common arithmetic operators on
-  // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder
-  // in return `SmallVector<AffineMap, 8>{{ {2} };`.
+  // 3. Get the list of affine maps for each input/output. The AffineExpr use
+  // the common arithmetic operators on AffineExpr. These affine maps will
+  // replace the `{2}` placeholder.
   std::string mapsStr;
   llvm::raw_string_ostream mapsStringStream(mapsStr);
+
   SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
   for (const auto &it : state.orderedTensorArgs)
     orderedUses[it.second] = it.first;
-  llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
-    assert(u.indexingMap);
-    const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)";
-    if (u.indexingMap.isEmpty()) {
-      mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context");
+
+  // Create a list of all symbols.
+  SmallVector<std::string, 4> symbolReplacements;
+  symbolReplacements.reserve(symbols.size());
+  for (unsigned i = 0; i < symbols.size(); ++i) {
+    const char *symFmt =
+        "\n\tauto s{0} = getAffineSymbolExpr({0}, context); (void)s{0};";
+    mapsStringStream << llvm::formatv(symFmt, i);
+    symbolReplacements.push_back(llvm::formatv("s{0}", i));
+  }
+
+  // Create the affine constant expressions to replace symbols for attributes.
+  for (auto attrUse : llvm::enumerate(attrUses)) {
+    StringRef attrName = attrUse.value().attrName;
+    auto it = registeredAttrs.find(attrName.str());
+    assert(it != registeredAttrs.end() && "uses should point to valid attr!");
+    std::string getValueFn = it->second.getValueFn(attrUse.value().indices);
+    if (getValueFn.empty()) {
+      parser.emitError("unimplemented getValueFn for attribute: " + attrName);
       return;
     }
+    std::string cstVal = llvm::formatv("{0}().{1}", attrName, getValueFn);
+    const char *cstFmt =
+        "\n\tauto cst{0} = getAffineConstantExpr({1}, context);";
+    mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal);
+
+    unsigned position =
+        attrUse.value().symbol.cast<AffineSymbolExpr>().getPosition();
+    symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index());
+  }
+
+  // For each tensor use, construct the affine map, replace symbols by the
+  // corresponding attribute values, and simplify the affine map.
+  for (auto tensorUse : llvm::enumerate(orderedUses)) {
+    auto indexingMap = tensorUse.value().indexingMap;
+    const char *mapFmt =
+        "\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);";
 
     std::string exprsStr;
     llvm::raw_string_ostream exprsStringStream(exprsStr);
     exprsStringStream << "{";
-    llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream);
+    llvm::interleaveComma(indexingMap.getResults(), exprsStringStream);
     exprsStringStream << "}";
     exprsStringStream.flush();
+    mapsStringStream << llvm::formatv(mapFmt, tensorUse.index(),
+                                      state.dims.size(),
+                                      indexingMap.getNumSymbols(), exprsStr);
+
+    std::string replaceSymbolList =
+        llvm::formatv("{ {0} }", llvm::join(symbolReplacements, ", "));
+
+    // Note that we use `0` as the result affine map's number of symbols. All
+    // symbols representing attribute usages should be folded away. But there
+    // may exist additional symbols for tensor dimension upper bounds. Linalg
+    // does not handle such cases right now. This needs to be fixed once we need
+    // that.
+    const char *replaceFmt =
+        "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
+    mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(),
+                                      replaceSymbolList, state.dims.size());
+    const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});";
+    mapsStringStream << llvm::formatv(simplifyFmt, tensorUse.index());
+  }
 
-    mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr);
-  });
   mapsStringStream.flush();
 
+  SmallVector<std::string, 4> mapList;
+  mapList.reserve(orderedUses.size());
+  for (unsigned i = 0; i < orderedUses.size(); ++i)
+    mapList.push_back(llvm::formatv("map{0}", i));
+
   // 4. Apply format to 1. using 2. and 3.
-  os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
+  os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr,
+                      llvm::join(mapList, ", "));
 }
 
 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
@@ -1893,6 +2077,31 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
                       expressionsStr, yieldStr);
 }
 
+std::string
+TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
+  if (isArray)
+    return "";
+
+  if (!vectorDims.empty()) {
+    SmallVector<std::string, 4> indexStrs;
+    for (uint64_t index : indices)
+      indexStrs.push_back(std::to_string(index));
+    std::string indexList = llvm::join(indexStrs, ", ");
+    if (elementType == "f32")
+      return llvm::formatv("getValue<float>({ {0} })", indexList);
+    if (elementType == "i32")
+      return llvm::formatv("getValue<int>({ {0} })", indexList);
+
+    return "";
+  }
+
+  if (elementType == "f32")
+    return "getValue().convertToFloat()";
+  if (elementType == "i32")
+    return "getInt()";
+  return "";
+}
+
 /// Iterate over each Tensor Comprehension def.
 LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
                                                   Parser &parser) {


        


More information about the Mlir-commits mailing list