[Mlir-commits] [mlir] 0acc260 - [mlir][linalg] Support generating builders for named op attributes

Lei Zhang llvmlistbot at llvm.org
Fri Jan 15 06:00:38 PST 2021


Author: Lei Zhang
Date: 2021-01-15T09:00:30-05:00
New Revision: 0acc260b574e28f5247e8ad4d8c9805b8005c841

URL: https://github.com/llvm/llvm-project/commit/0acc260b574e28f5247e8ad4d8c9805b8005c841
DIFF: https://github.com/llvm/llvm-project/commit/0acc260b574e28f5247e8ad4d8c9805b8005c841.diff

LOG: [mlir][linalg] Support generating builders for named op attributes

This commit adds support for generating an additional builder for
each named op that has attributes. This gives a better experience
when creating named ops.

Along the way, this also adds support for i64 as an attribute element type.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D94733

Added: 
    

Modified: 
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 226a09669b1c..9bd6152f07da 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -88,6 +88,7 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
 // ODS: F32:$f32_attr,
 // ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
 // ODS: I32:$i32_attr,
+// ODS: I64:$i64_attr,
 // ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
 // ODS: OptionalAttr<F32>:$optional_attr
 //
@@ -96,6 +97,7 @@ def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
 attr(
   f32_attr: f32,
   i32_attr: i32,
+  i64_attr: i64,
   fvec_attr: 4xf32,
   ivec_attr: 5x6xi32,
   array_attr : f32[],
@@ -126,6 +128,7 @@ def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F))
     I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c)));
 }
 
+// Test documentation
 // ODS-LABEL: def Test6Op
 // ODS:       let summary = [{ My magic op. }];
 // ODS-NEXT:  let description = [{
@@ -144,3 +147,18 @@ It has one output.
 {
   C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
 }
+
+// Test attribute builder
+// ODS-LABEL: def Test7Op
+// ODS:         OpBuilderDAG<
+// ODS:           (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+// ODS:            "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b)
+// ODS:           $_state.addAttribute("attr_a", attr_a);
+// ODS:           $_state.addAttribute("attr_b", attr_b);
+//
+ods_def<Test7Op>:
+def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M))
+     attr(attr_a: f32, attr_b: 4xi32)
+{
+  C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index f4b7f9f9323a..47841c840fe5 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1768,6 +1768,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
     std::string odsType = llvm::StringSwitch<std::string>(elementType)
                               .Case("f32", "F32")
                               .Case("i32", "I32")
+                              .Case("i64", "I64")
                               .Default("");
     if (odsType.empty()) {
       parser.emitError("unimplemented support for attribute element type: " +
@@ -1811,7 +1812,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       let regions = (region AnyRegion:$region);
 
       let skipDefaultBuilders = 1;
-      let builders = [ OpBuilderDAG<
+      let builders = [
+        OpBuilderDAG<
         (ins "ValueRange":$inputs, "ValueRange":$outputs),
         [{{
           $_state.addOperands(inputs);
@@ -1826,7 +1828,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_state,
             TypeRange(inputs),
             TypeRange(outputs));
-        }]>, OpBuilderDAG<
+        }]>,
+        OpBuilderDAG<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
              "ValueRange":$outputs),
         [{{
@@ -1843,7 +1846,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_state,
             TypeRange(inputs),
             TypeRange(outputs));
-        }]>, OpBuilderDAG<
+        }]>,
+        OpBuilderDAG<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
              CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
         [{{
@@ -1852,6 +1856,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           $_state.addTypes(resultTensorTypes);
           (void)$_state.addRegion();
         }]>
+        {5}
       ];
       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
       let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
@@ -1873,8 +1878,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       }];
   })FMT";
 
+  // Generate documentation.
   std::string doc;
-
   if (!docString.empty()) {
     const char *docFmt = R"FMT(
       let summary = [{ {0} }];
@@ -1888,8 +1893,47 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
     doc = llvm::formatv(docFmt, summary.trim(), description.trim());
   }
 
+  // Generate an additional builder that has parameters for attributes.
+  std::string attrBuilder;
+  if (!registeredAttrs.empty()) {
+    SmallVector<std::string, 4> attrParams, attrStmts;
+    for (const auto &attr : registeredAttrs) {
+      llvm::StringRef name = attr.first;
+      attrParams.push_back(llvm::formatv("\"Attribute\":${0}", name));
+      attrStmts.push_back(
+          llvm::formatv("$_state.addAttribute(\"{0}\", {0});", name));
+    }
+    std::string attrParamsList = llvm::join(attrParams, ", ");
+    std::string attrStmtsList = llvm::join(attrStmts, "\n");
+
+    const char *builderFmt = R"FMT(
+      , OpBuilderDAG<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs, {1}),
+      [{{
+        $_state.addOperands(inputs);
+        $_state.addOperands(outputs);
+        $_state.addTypes(resultTensorTypes);
+        $_state.addAttribute(
+          "operand_segment_sizes",
+          $_builder.getI32VectorAttr({{
+            static_cast<int32_t>(inputs.size()),
+            static_cast<int32_t>(outputs.size())}));
+        buildNamedStructuredOpRegionAndAttributes<{0}>(
+          $_builder,
+          $_state,
+          TypeRange(inputs),
+          TypeRange(outputs));
+        {2}
+      }]>
+    )FMT";
+    attrBuilder =
+        llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList);
+  }
+
+  // Finally put everything together.
   os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList,
-                      state.orderedTensorArgs.size());
+                      state.orderedTensorArgs.size(), attrBuilder);
 }
 
 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -2146,13 +2190,15 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
       return llvm::formatv("getValue<float>({ {0} })", indexList);
     if (elementType == "i32")
       return llvm::formatv("getValue<int>({ {0} })", indexList);
+    if (elementType == "i64")
+      return llvm::formatv("getValue<int64_t>({ {0} })", indexList);
 
     return "";
   }
 
   if (elementType == "f32")
     return "getValue().convertToFloat()";
-  if (elementType == "i32")
+  if (elementType == "i32" || elementType == "i64")
     return "getInt()";
   return "";
 }


        


More information about the Mlir-commits mailing list