[Mlir-commits] [mlir] 564bcf9 - Align adaptor's generator accessors for attribute on the Op class

Mehdi Amini llvmlistbot at llvm.org
Tue Jan 4 21:42:26 PST 2022


Author: Mehdi Amini
Date: 2022-01-05T05:42:15Z
New Revision: 564bcf9d0243ef4f3f5ed848f7195fd0489c77a5

URL: https://github.com/llvm/llvm-project/commit/564bcf9d0243ef4f3f5ed848f7195fd0489c77a5
DIFF: https://github.com/llvm/llvm-project/commit/564bcf9d0243ef4f3f5ed848f7195fd0489c77a5.diff

LOG: Align adaptor's generator accessors for attribute on the Op class

Each attribute has two accessors: one suffixed with `Attr`, which returns the attribute itself,
and one without the suffix, which unwraps the attribute and returns its underlying value.
For example, for a StringAttr attribute named `kind`, we'll generate:

StringAttr getKindAttr();
StringRef getKind();

Differential Revision: https://reviews.llvm.org/D116466

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 828f0ef151203..0d3157d515de7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -209,7 +209,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
     (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
   } while (moduleOp.lookupSymbol(stringConstName));
 
-  llvm::SmallString<20> formatString(adaptor.format().getValue());
+  llvm::SmallString<20> formatString(adaptor.format());
   formatString.push_back('\0'); // Null terminate for C
   size_t formatStringSize = formatString.size_in_bytes();
 
@@ -309,7 +309,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
     (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
   } while (moduleOp.lookupSymbol(stringConstName));
 
-  llvm::SmallString<20> formatString(adaptor.format().getValue());
+  llvm::SmallString<20> formatString(adaptor.format());
   formatString.push_back('\0'); // Null terminate for C
   auto globalType =
       LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index da01ec496bec7..1283dec2a0a5b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -281,8 +281,8 @@ class VectorTransferReadOpConverter
       return failure();
     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
         readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
-        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
-        adaptor.in_bounds());
+        adaptor.permutation_mapAttr(), adaptor.padding(), adaptor.mask(),
+        adaptor.in_boundsAttr());
     return success();
   }
 };
@@ -299,8 +299,8 @@ class VectorTransferWriteOpConverter
       return failure();
     rewriter.create<vector::TransferWriteOp>(
         writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
-        adaptor.permutation_map(),
-        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
+        adaptor.permutation_mapAttr(),
+        adaptor.in_bounds() ? adaptor.in_boundsAttr() : ArrayAttr());
     rewriter.replaceOp(writeOp, adaptor.source());
     return success();
   }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 69da41c343709..303e433e1b16b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1206,7 +1206,7 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     inWidth = inputShape.getDimSize(2);
   }
 
-  int32_t shift = adaptor.shift().getValue().getSExtValue();
+  int32_t shift = adaptor.shift();
   llvm::SmallVector<int64_t> newShape;
   getI64Values(adaptor.output_size(), newShape);
   outputShape[1] = newShape[0];

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 5ae0e9b18af01..3b74275e85a9a 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -59,8 +59,10 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   ::mlir::ValueRange getODSOperands(unsigned index);
 // CHECK:   ::mlir::Value getA();
 // CHECK:   ::mlir::ValueRange getB();
-// CHECK:   ::mlir::IntegerAttr getAttr1();
-// CHECK:   ::mlir::FloatAttr getAttr2();
+// CHECK:   ::mlir::IntegerAttr getAttr1Attr();
+// CHECK:   uint32_t getAttr1();
+// CHECK:   ::mlir::FloatAttr getAttr2Attr();
+// CHECK:   ::llvm::Optional< ::llvm::APFloat > getAttr2();
 // CHECK:   ::mlir::Region &getSomeRegion();
 // CHECK:   ::mlir::RegionRange getSomeRegions();
 // CHECK: private:

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 638369e90c3a1..f024b90d33404 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -715,6 +715,33 @@ void OpEmitter::genAttrNameGetters() {
   }
 }
 
+// Emit the getter for an attribute with the return type specified.
+// It is templated to be shared between the Op and the adaptor class.
+template <typename OpClassOrAdaptor>
+static void emitAttrGetterWithReturnType(FmtContext &fctx,
+                                         OpClassOrAdaptor &opClass,
+                                         const Operator &op, StringRef name,
+                                         Attribute attr) {
+  auto *method = opClass.addMethod(attr.getReturnType(), name);
+  ERROR_IF_PRUNED(method, name, op);
+  auto &body = method->body();
+  body << "  auto attr = " << name << "Attr();\n";
+  if (attr.hasDefaultValue()) {
+    // Returns the default value if not set.
+    // TODO: this is inefficient, we are recreating the attribute for every
+    // call. This should be set instead.
+    std::string defaultValue = std::string(
+        tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+    body << "    if (!attr)\n      return "
+         << tgfmt(attr.getConvertFromStorageCall(),
+                  &fctx.withSelf(defaultValue))
+         << ";\n";
+  }
+  body << "  return "
+       << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
+       << ";\n";
+}
+
 void OpEmitter::genAttrGetters() {
   FmtContext fctx;
   fctx.withBuilder("::mlir::Builder((*this)->getContext())");
@@ -725,28 +752,6 @@ void OpEmitter::genAttrGetters() {
       method->body() << "  " << attr.getDerivedCodeBody() << "\n";
   };
 
-  // Emit with return type specified.
-  auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
-    auto *method = opClass.addMethod(attr.getReturnType(), name);
-    ERROR_IF_PRUNED(method, name, op);
-    auto &body = method->body();
-    body << "  auto attr = " << name << "Attr();\n";
-    if (attr.hasDefaultValue()) {
-      // Returns the default value if not set.
-      // TODO: this is inefficient, we are recreating the attribute for every
-      // call. This should be set instead.
-      std::string defaultValue = std::string(
-          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
-      body << "    if (!attr)\n      return "
-           << tgfmt(attr.getConvertFromStorageCall(),
-                    &fctx.withSelf(defaultValue))
-           << ";\n";
-    }
-    body << "  return "
-         << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
-         << ";\n";
-  };
-
   // Generate named accessor with Attribute return type. This is a wrapper class
   // that allows referring to the attributes via accessors instead of having to
   // use the string interface for better compile time verification.
@@ -767,7 +772,7 @@ void OpEmitter::genAttrGetters() {
         emitDerivedAttr(name, namedAttr.attr);
       } else {
         emitAttrWithStorageType(name, namedAttr.attr);
-        emitAttrWithReturnType(name, namedAttr.attr);
+        emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
       }
     }
   }
@@ -2588,9 +2593,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
   FmtContext fctx;
   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
 
-  auto emitAttr = [&](StringRef name, StringRef emitName, Attribute attr) {
-    auto *method = adaptor.addMethod(attr.getStorageType(), emitName);
-    ERROR_IF_PRUNED(method, "Adaptor::" + emitName, op);
+  // Generate named accessor with Attribute return type.
+  auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
+                                     Attribute attr) {
+    auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
+    ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
     auto &body = method->body();
     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
          << "\n  " << attr.getStorageType() << " attr = "
@@ -2620,9 +2627,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
   for (auto &namedAttr : op.getAttributes()) {
     const auto &name = namedAttr.name;
     const auto &attr = namedAttr.attr;
-    if (!attr.isDerivedAttr()) {
-      for (const auto &emitName : op.getGetterNames(name))
-        emitAttr(name, emitName, attr);
+    if (attr.isDerivedAttr())
+      continue;
+    for (const auto &emitName : op.getGetterNames(name)) {
+      emitAttrWithStorageType(name, emitName, attr);
+      emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
     }
   }
 


        


More information about the Mlir-commits mailing list