[Mlir-commits] [mlir] baca1c1 - [mlir][ods] Make Attr/Type def accessors match the dialect

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 13 22:13:29 PDT 2022


Author: Mogball
Date: 2022-06-14T05:13:24Z
New Revision: baca1c1ac41ebb11cb93c82188d5cc7b203a2f5b

URL: https://github.com/llvm/llvm-project/commit/baca1c1ac41ebb11cb93c82188d5cc7b203a2f5b
DIFF: https://github.com/llvm/llvm-project/commit/baca1c1ac41ebb11cb93c82188d5cc7b203a2f5b.diff

LOG: [mlir][ods] Make Attr/Type def accessors match the dialect

The generated attribute and type def accessors are changed to match the setting on the dialect. Most importantly, "prefixed" will now correctly convert snake case to camel case (e.g. `weight_zp` -> `getWeightZp`).

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D127688

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index c2aafca0831b0..0a23e0fed56cc 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -58,6 +58,9 @@ class AttrOrTypeParameter {
   /// Get the parameter name.
   StringRef getName() const;
 
+  /// Get the parameter accessor name.
+  std::string getAccessorName() const;
+
   /// If specified, get the custom allocator code for this parameter.
   Optional<StringRef> getAllocator() const;
 

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b1d0b696009da..6a14e6a0f4d47 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -147,8 +147,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
       cast<tosa::NegateOp>(op).quantization_info()) {
     auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
     int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    int64_t inZp = quantizationInfo.getValue().getInput_zp();
-    int64_t outZp = quantizationInfo.getValue().getOutput_zp();
+    int64_t inZp = quantizationInfo.getValue().getInputZp();
+    int64_t outZp = quantizationInfo.getValue().getOutputZp();
 
     // Compute the maximum value that can occur in the intermediate buffer.
     int64_t zpAdd = inZp + outZp;
@@ -1847,7 +1847,7 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
       } else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) {
         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
       } else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
-        int64_t value = padOp.quantization_info().getValue().getInput_zp();
+        int64_t value = padOp.quantization_info().getValue().getInputZp();
         constantAttr = rewriter.getIntegerAttr(elementTy, value);
       }
       if (constantAttr)

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2154cd98204e0..bd3eb0feca647 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -202,7 +202,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      int64_t iZp = quantizationInfo.getInput_zp();
+      int64_t iZp = quantizationInfo.getInputZp();
 
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -274,8 +274,8 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
-      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());
+      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
+      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -366,8 +366,8 @@ class DepthwiseConvConverter
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
-      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());
+      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
+      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
     }
 
     auto weightShape = weightTy.getShape();
@@ -378,7 +378,7 @@ class DepthwiseConvConverter
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-      int64_t iZp = quantizationInfo.getInput_zp();
+      int64_t iZp = quantizationInfo.getInputZp();
 
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -542,9 +542,9 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
 
     auto quantizationInfo = op.quantization_info().getValue();
     auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getA_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
     auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getB_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);
@@ -652,9 +652,9 @@ class FullyConnectedConverter
 
     auto quantizationInfo = op.quantization_info().getValue();
     auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
     auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()));
+        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -892,8 +892,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             if (op.quantization_info()) {
               auto quantizationInfo = op.quantization_info().getValue();
               auto inputZp = rewriter.create<arith::ConstantOp>(
-                  loc,
-                  b.getIntegerAttr(accETy, quantizationInfo.getInput_zp()));
+                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
               poolVal =
@@ -930,7 +929,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
               auto quantizationInfo = op.quantization_info().getValue();
               auto outputZp = rewriter.create<arith::ConstantOp>(
                   loc, b.getIntegerAttr(scaled.getType(),
-                                        quantizationInfo.getOutput_zp()));
+                                        quantizationInfo.getOutputZp()));
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index 17f3412999f3f..b588afa106476 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -145,7 +145,7 @@ spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
 
 DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
   if (auto entryPoint = spirv::lookupEntryPointABI(op))
-    return entryPoint.getLocal_size();
+    return entryPoint.getLocalSize();
 
   return {};
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 3a01557f5c9de..3ea2224a4408a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -135,7 +135,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
       funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars);
 
   // Specifies the spv.ExecutionModeOp.
-  auto localSizeAttr = entryPointAttr.getLocal_size();
+  auto localSizeAttr = entryPointAttr.getLocalSize();
   if (localSizeAttr) {
     auto values = localSizeAttr.getValues<int32_t>();
     SmallVector<int32_t, 3> localSize(values);

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d84e5d218dc2a..d7203c6afbe46 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -347,7 +347,7 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     } else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
     } else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
-      auto value = op.quantization_info().getValue().getInput_zp();
+      auto value = op.quantization_info().getValue().getInputZp();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 3389dda46e1b0..1db101280ef23 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -214,7 +214,7 @@ class TransposeConvStridedConverter
       weight = createOpAndInfer<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
           weightPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeight_zp()));
+          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
 
     } else {
       weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
@@ -278,7 +278,7 @@ class TransposeConvStridedConverter
       input = createOpAndInfer<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
           inputPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInput_zp()));
+          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
     } else {
       input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                             UnrankedTensorType::get(inputETy),

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 444db742bd32e..8467af0ee74c0 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -215,6 +215,11 @@ StringRef AttrOrTypeParameter::getName() const {
   return def->getArgName(index)->getValue();
 }
 
+std::string AttrOrTypeParameter::getAccessorName() const {
+  return "get" +
+         llvm::convertToCamelFromSnakeCase(getName(), /*capitalizeFirst=*/true);
+}
+
 Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
   return getDefValue<llvm::StringInit>("allocator");
 }

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 6e32b431bf823..6895cb15fcdb5 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -7,16 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "AttrOrTypeFormatGen.h"
-#include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Class.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Interfaces.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/TableGen/Error.h"
@@ -31,13 +27,6 @@ using namespace mlir::tblgen;
 // Utility Functions
 //===----------------------------------------------------------------------===//
 
-std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
-  assert(!name.empty() && "parameter has empty name");
-  auto ret = "get" + name.str();
-  ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
-  return ret;
-}
-
 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
 /// specified and can only find one dialect's defs, use that.
 static void collectAllDefs(StringRef selectedDialect,
@@ -288,7 +277,7 @@ void DefGen::emitParserPrinter() {
 void DefGen::emitAccessors() {
   for (auto &param : params) {
     Method *m = defCls.addMethod(
-        param.getCppAccessorType(), getParameterAccessorName(param.getName()),
+        param.getCppAccessorType(), param.getAccessorName(),
         def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
     // Generate accessor definitions only if we also generate the storage
     // class. Otherwise, let the user define the exact accessor definition.

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index e1b0774e2a2b8..a31943790aada 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -58,7 +58,7 @@ class ParameterElement
 
   /// Generate the code to check whether the parameter should be printed.
   MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
-    std::string self = getParameterAccessorName(getName()) + "()";
+    std::string self = param.getAccessorName() + "()";
     ctx.withSelf(self);
     os << tgfmt("($_self", &ctx);
     if (llvm::Optional<StringRef> defaultValue = getParam().getDefaultValue()) {
@@ -718,7 +718,7 @@ void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
 void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
                                    MethodBody &os, bool skipGuard) {
   const AttrOrTypeParameter &param = el->getParam();
-  ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+  ctx.withSelf(param.getAccessorName() + "()");
 
   // Guard the printer on the presence of optional parameters and that they
   // aren't equal to their default values (if they have one).
@@ -812,8 +812,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
     if (auto *ref = dyn_cast<RefDirective>(arg))
       param = ref->getArg();
     os << ",\n"
-       << getParameterAccessorName(cast<ParameterElement>(param)->getName())
-       << "()";
+       << cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
   }
   os.unindent() << ");\n";
 }

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
index c371aee268b42..d4711532a79bb 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
@@ -20,11 +20,6 @@ class AttrOrTypeDef;
 void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser,
                               MethodBody &printer);
 
-/// From the parameter name, get the name of the accessor function in camelcase.
-/// The first letter of the parameter is upper-cased and prefixed with "get".
-/// E.g. 'value' -> 'getValue'.
-std::string getParameterAccessorName(llvm::StringRef name);
-
 } // namespace tblgen
 } // namespace mlir
 


        


More information about the Mlir-commits mailing list