[Mlir-commits] [mlir] d5d4fb6 - [mlir][linalg] Add support for using scalar attributes in TC ops.

Hanhan Wang llvmlistbot at llvm.org
Wed Mar 10 01:51:44 PST 2021


Author: Hanhan Wang
Date: 2021-03-10T01:51:12-08:00
New Revision: d5d4fb635ee0b3852f4fee1c98bac87c5811e051

URL: https://github.com/llvm/llvm-project/commit/d5d4fb635ee0b3852f4fee1c98bac87c5811e051
DIFF: https://github.com/llvm/llvm-project/commit/d5d4fb635ee0b3852f4fee1c98bac87c5811e051.diff

LOG: [mlir][linalg] Add support for using scalar attributes in TC ops.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D97876

Added: 
    

Modified: 
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 7627a017a07e..8c4c31da2fce 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -85,12 +85,12 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
 // Test attribute definitions
 // ODS-LABEL: def Test4Op
 // ODS: F32ArrayAttr:$array_attr,
-// ODS: F32:$f32_attr,
+// ODS: F32Attr:$f32_attr,
 // ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
-// ODS: I32:$i32_attr,
-// ODS: I64:$i64_attr,
+// ODS: I32Attr:$i32_attr,
+// ODS: I64Attr:$i64_attr,
 // ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
-// ODS: OptionalAttr<F32>:$optional_attr
+// ODS: OptionalAttr<F32Attr>:$optional_attr
 //
 // ODS: bool hasDynamicIndexingMaps();
 // ODS: LogicalResult verifyIndexingMapRequiredAttributes();

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 9d2d26a5cbd2..7165a0fe89fe 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1174,7 +1174,7 @@ class TCParser {
 
     // Returns the function to get values at the given indices from this
     // attribute.
-    std::string getValueFn(ArrayRef<uint64_t> indices) const;
+    llvm::Optional<std::string> getValueFn(ArrayRef<uint64_t> indices) const;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1841,16 +1841,19 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
 
     const auto &dims = attr.second.vectorDims;
     if (!dims.empty()) {
+      // Vector case
       SmallVector<std::string, 4> dimStrs;
       for (uint64_t dim : dims)
         dimStrs.push_back(std::to_string(dim));
       odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
                               llvm::join(dimStrs, ", "));
-    }
-
-    assert(dims.empty() || !attr.second.isArray);
-    if (attr.second.isArray)
+    } else if (attr.second.isArray) {
+      // Array case
       odsType = llvm::formatv("{0}ArrayAttr", odsType);
+    } else {
+      // Scalar case
+      odsType = llvm::formatv("{0}Attr", odsType);
+    }
 
     if (attr.second.isOptional)
       odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
@@ -2242,13 +2245,14 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
     StringRef attrName = attrUse.value().attrName;
     auto it = registeredAttrs.find(attrName.str());
     assert(it != registeredAttrs.end() && "uses should point to valid attr!");
-    std::string getValueFn = it->second.getValueFn(attrUse.value().indices);
-    if (getValueFn.empty()) {
+    llvm::Optional<std::string> getValueFn =
+        it->second.getValueFn(attrUse.value().indices);
+    if (!getValueFn) {
       (void)parser.emitError("unimplemented getValueFn for attribute: " +
                              attrName);
       return;
     }
-    std::string cstVal = llvm::formatv("{0}().{1}", attrName, getValueFn);
+    std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn);
     const char *cstFmt =
         "\n\tauto cst{0} = getAffineConstantExpr({1}, context);";
     mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal);
@@ -2374,10 +2378,10 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
                       expressionsStr, yieldStr);
 }
 
-std::string
+llvm::Optional<std::string>
 TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
   if (isArray)
-    return "";
+    return llvm::None;
 
   if (!vectorDims.empty()) {
     SmallVector<std::string, 4> indexStrs;
@@ -2385,20 +2389,20 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
       indexStrs.push_back(std::to_string(index));
     std::string indexList = llvm::join(indexStrs, ", ");
     if (elementType == "f32")
-      return llvm::formatv("getValue<float>({ {0} })", indexList);
+      return llvm::formatv(".getValue<float>({ {0} })", indexList).str();
     if (elementType == "i32")
-      return llvm::formatv("getValue<int>({ {0} })", indexList);
+      return llvm::formatv(".getValue<int>({ {0} })", indexList).str();
     if (elementType == "i64")
-      return llvm::formatv("getValue<int64_t>({ {0} })", indexList);
+      return llvm::formatv(".getValue<int64_t>({ {0} })", indexList).str();
 
-    return "";
+    return llvm::None;
   }
 
   if (elementType == "f32")
-    return "getValue().convertToFloat()";
+    return std::string(".convertToFloat()");
   if (elementType == "i32" || elementType == "i64")
-    return "getInt()";
-  return "";
+    return std::string("");
+  return llvm::None;
 }
 
 /// Iterate over each Tensor Comprehension def.


        


More information about the Mlir-commits mailing list