[Mlir-commits] [mlir] faf1c22 - [Builder] Eliminate the StringRef/StringAttr forms of getSymbolRefAttr.

Chris Lattner llvmlistbot at llvm.org
Mon Aug 30 16:11:51 PDT 2021


Author: Chris Lattner
Date: 2021-08-30T16:05:36-07:00
New Revision: faf1c22408cfea27749632d880e0a9c0d6b7a568

URL: https://github.com/llvm/llvm-project/commit/faf1c22408cfea27749632d880e0a9c0d6b7a568
DIFF: https://github.com/llvm/llvm-project/commit/faf1c22408cfea27749632d880e0a9c0d6b7a568.diff

LOG: [Builder] Eliminate the StringRef/StringAttr forms of getSymbolRefAttr.

The StringAttr version doesn't need a context, so we can just use the
existing `SymbolRefAttr::get` form.  The StringRef version isn't preferred
so we want to encourage people to use StringAttr.

The additional form of getSymbolRefAttr that takes a (SymbolTrait-
implementing) operation is removed as well: its callers now use the new
`SymbolRefAttr::get(Operation *)` overload (also exposed as
`FlatSymbolRefAttr::get(Operation *)`).

Differential Revision: https://reviews.llvm.org/D108922

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Lower/IntrinsicCall.cpp
    flang/lib/Optimizer/Dialect/FIROps.cpp
    mlir/examples/toy/Ch2/mlir/Dialect.cpp
    mlir/examples/toy/Ch3/mlir/Dialect.cpp
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
    mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
    mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Transforms/NormalizeMemRefs.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 2d756fc1e961f..300a5b6614f54 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2545,7 +2545,7 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
     [{
       $_state.addOperands(operands);
       $_state.addAttribute(calleeAttrName($_state.name),
-        $_builder.getSymbolRefAttr(callee));
+                           SymbolRefAttr::get(callee));
       $_state.addTypes(callee.getType().getResults());
     }]>,
     OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
@@ -2560,7 +2560,8 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
         "llvm::ArrayRef<mlir::Type>":$results,
         CArg<"mlir::ValueRange", "{}">:$operands),
     [{
-      build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results,
+      build($_builder, $_state,
+            SymbolRefAttr::get($_builder.getContext(), callee), results,
             operands);
     }]>];
 

diff  --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
index b9c2bba03631c..2d00e9098b416 100644
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -919,7 +919,7 @@ mlir::SymbolRefAttr IntrinsicLibrary::getUnrestrictedIntrinsicSymbolRefAttr(
     funcOp = getWrapper(rtCallGenerator, name, signature, loadRefArguments);
   }
 
-  return builder.getSymbolRefAttr(funcOp.getName());
+  return SymbolRefAttr::get(funcOp);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 789964c7f8968..c10f45e9847a9 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -398,8 +398,7 @@ mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser,
 //===----------------------------------------------------------------------===//
 
 void fir::ConvertOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-}
+    OwningRewritePatternList &results, MLIRContext *context) {}
 
 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
   if (value().getType() == getType())
@@ -629,7 +628,8 @@ void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
   result.addAttribute(typeAttrName(result.name), mlir::TypeAttr::get(type));
   result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(name));
-  result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name));
+  result.addAttribute(symbolAttrName(),
+                      SymbolRefAttr::get(builder.getContext(), name));
   if (isConstant)
     result.addAttribute(constantAttrName(result.name), builder.getUnitAttr());
   if (initialVal)
@@ -1330,7 +1330,7 @@ static constexpr llvm::StringRef getTargetOffsetAttr() {
 template <typename A, typename... AdditionalArgs>
 static A getSubOperands(unsigned pos, A allArgs,
                         mlir::DenseIntElementsAttr ranges,
-                        AdditionalArgs &&... additionalArgs) {
+                        AdditionalArgs &&...additionalArgs) {
   unsigned start = 0;
   for (unsigned i = 0; i < pos; ++i)
     start += (*(ranges.begin() + i)).getZExtValue();

diff  --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index 5213d336d5482..9704fa4a4d8d1 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -174,7 +174,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
   // Generic call always returns an unranked Tensor initially.
   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
   state.addOperands(arguments);
-  state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+  state.addAttribute("callee",
+                     mlir::SymbolRefAttr::get(builder.getContext(), callee));
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 5213d336d5482..9704fa4a4d8d1 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -174,7 +174,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
   // Generic call always returns an unranked Tensor initially.
   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
   state.addOperands(arguments);
-  state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+  state.addAttribute("callee",
+                     mlir::SymbolRefAttr::get(builder.getContext(), callee));
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index ff1d4cdcd2a9a..57528611ca3c9 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -256,7 +256,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
   // Generic call always returns an unranked Tensor initially.
   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
   state.addOperands(arguments);
-  state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+  state.addAttribute("callee",
+                     mlir::SymbolRefAttr::get(builder.getContext(), callee));
 }
 
 /// Return the callee of the generic call operation, this is required by the

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 89e7529a4a948..94ab83208c1c9 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -256,7 +256,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
   // Generic call always returns an unranked Tensor initially.
   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
   state.addOperands(arguments);
-  state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+  state.addAttribute("callee",
+                     mlir::SymbolRefAttr::get(builder.getContext(), callee));
 }
 
 /// Return the callee of the generic call operation, this is required by the

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 89e7529a4a948..94ab83208c1c9 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -256,7 +256,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
   // Generic call always returns an unranked Tensor initially.
   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
   state.addOperands(arguments);
-  state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+  state.addAttribute("callee",
+                     mlir::SymbolRefAttr::get(builder.getContext(), callee));
 }
 
 /// Return the callee of the generic call operation, this is required by the

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 30d473f7bec20..a0acff11a30d1 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -282,7 +282,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
   // Generic call always returns an unranked Tensor initially.
   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
   state.addOperands(arguments);
-  state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+  state.addAttribute("callee",
+                     mlir::SymbolRefAttr::get(builder.getContext(), callee));
 }
 
 /// Return the callee of the generic call operation, this is required by the

diff  --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
index 4dca519c3e80a..f4e8ced97af7c 100644
--- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -522,7 +522,7 @@ class MLIRGenImpl {
     mlir::FuncOp calledFunc = calledFuncIt->second;
     return builder.create<GenericCallOp>(
         location, calledFunc.getType().getResult(0),
-        builder.getSymbolRefAttr(callee), operands);
+        mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
   }
 
   /// Emit a print expression. It emits specific operations for two builtins:

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8505f6f437170..5163125652f47 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -515,14 +515,22 @@ def LLVM_CallOp : LLVM_Op<"call",
   let results = (outs Variadic<LLVM_Type>);
   let builders = [
     OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
-    [{
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), [{
       Type resultType = func.getType().getReturnType();
       if (!resultType.isa<LLVM::LLVMVoidType>())
         $_state.addTypes(resultType);
-      $_state.addAttribute("callee", $_builder.getSymbolRefAttr(func));
+      $_state.addAttribute("callee", SymbolRefAttr::get(func));
       $_state.addAttributes(attributes);
       $_state.addOperands(operands);
+    }]>,
+    OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
+                   CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, results, SymbolRefAttr::get(callee), operands);
+    }]>,
+    OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
+                   CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, results,
+            StringAttr::get($_builder.getContext(), callee), operands);
     }]>];
   let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseCallOp(parser, result); }];

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b715070706e37..1633c0d98b63d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -560,7 +560,7 @@ def CallOp : Std_Op<"call",
   let builders = [
     OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
       $_state.addOperands(operands);
-      $_state.addAttribute("callee",$_builder.getSymbolRefAttr(callee));
+      $_state.addAttribute("callee", SymbolRefAttr::get(callee));
       $_state.addTypes(callee.getType().getResults());
     }]>,
     OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results,
@@ -569,14 +569,19 @@ def CallOp : Std_Op<"call",
       $_state.addAttribute("callee", callee);
       $_state.addTypes(results);
     }]>,
+    OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, SymbolRefAttr::get(callee), results, operands);
+    }]>,
     OpBuilder<(ins "StringRef":$callee, "TypeRange":$results,
       CArg<"ValueRange", "{}">:$operands), [{
-      build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results,
-            operands);
+      build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
+            results, operands);
     }]>];
 
   let extraClassDeclaration = [{
     StringRef getCallee() { return callee(); }
+    StringAttr getCalleeAttr() { return calleeAttr().getAttr(); }
     FunctionType getCalleeType();
 
     /// Get the argument operands to the called function.

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 7e6aa710e1f94..102fb3dc40e98 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -97,17 +97,6 @@ class Builder {
   FloatAttr getFloatAttr(Type type, const APFloat &value);
   StringAttr getStringAttr(const Twine &bytes);
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
-  FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
-  FlatSymbolRefAttr getSymbolRefAttr(StringAttr value);
-  SymbolRefAttr getSymbolRefAttr(StringAttr value,
-                                 ArrayRef<FlatSymbolRefAttr> nestedReferences);
-  SymbolRefAttr getSymbolRefAttr(StringRef value,
-                                 ArrayRef<FlatSymbolRefAttr> nestedReferences) {
-    return getSymbolRefAttr(getStringAttr(value), nestedReferences);
-  }
-  FlatSymbolRefAttr getSymbolRefAttr(StringRef value) {
-    return getSymbolRefAttr(getStringAttr(value));
-  }
 
   // Returns a 0-valued attribute of the given `type`. This function only
   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 0240e17e83419..e0ede99b19af6 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -23,6 +23,7 @@ class FunctionType;
 class IntegerSet;
 class IntegerType;
 class Location;
+class Operation;
 class ShapedType;
 
 //===----------------------------------------------------------------------===//
@@ -685,12 +686,17 @@ class FlatSymbolRefAttr : public SymbolRefAttr {
   using ValueType = StringRef;
 
   /// Construct a symbol reference for the given value name.
+  static FlatSymbolRefAttr get(StringAttr value) {
+    return SymbolRefAttr::get(value);
+  }
   static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) {
     return SymbolRefAttr::get(ctx, value);
   }
 
-  static FlatSymbolRefAttr get(StringAttr value) {
-    return SymbolRefAttr::get(value);
+  /// Convenience getter for building a SymbolRefAttr based on an operation
+  /// that implements the SymbolTrait.
+  static FlatSymbolRefAttr get(Operation *symbol) {
+    return SymbolRefAttr::get(symbol);
   }
 
   /// Returns the name of the held symbol reference as a StringAttr.

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 08c3d0f2ebade..228f1c6dca992 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -893,8 +893,16 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
     }]>,
   ];
   let extraClassDeclaration = [{
-    static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
+    static SymbolRefAttr get(MLIRContext *ctx, StringRef value,
+                             ArrayRef<FlatSymbolRefAttr> nestedRefs);
+    /// Convenience getters for building a SymbolRefAttr with no path, which is
+    /// known to produce a FlatSymbolRefAttr.
     static FlatSymbolRefAttr get(StringAttr value);
+    static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
+
+    /// Convenience getter for building a SymbolRefAttr based on an operation
+    /// that implements the SymbolTrait.
+    static FlatSymbolRefAttr get(Operation *symbol);
 
     /// Returns the name of the fully resolved symbol, i.e. the leaf of the
     /// reference path.

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index facd2a8d77a14..4c8b5e5014e0a 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1582,15 +1582,19 @@ def SymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::SymbolRefAttr>()">,
   let storageType = [{ ::mlir::SymbolRefAttr }];
   let returnType = [{ ::mlir::SymbolRefAttr }];
   let valueType = NoneType;
-  let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
+  // Fully qualified since generated code may live outside namespace mlir.
+  let constBuilderCall =
+      "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
   let convertFromStorage = "$_self";
 }
+
 def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::FlatSymbolRefAttr>()">,
                                    "flat symbol reference attribute"> {
   let storageType = [{ ::mlir::FlatSymbolRefAttr }];
   let returnType = [{ ::llvm::StringRef }];
   let valueType = NoneType;
-  let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
+  let constBuilderCall =
+      "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
   let convertFromStorage = "$_self.getValue()";
 }
 

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 8626efdfe4a97..4f4dd0ed5044c 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -367,7 +367,7 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
 
     // Allocate memory for the coroutine frame.
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
-        loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
+        loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc),
         ValueRange(coroSize.getResult()));
 
     // Begin a coroutine: @llvm.coro.begin.
@@ -399,9 +399,9 @@ class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
     auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
 
     // Free the memory.
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
-                                              rewriter.getSymbolRefAttr(kFree),
-                                              ValueRange(coroMem.getResult()));
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree),
+        ValueRange(coroMem.getResult()));
 
     return success();
   }

diff  --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 20b7f9ad448c7..b8781fc68b346 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -62,8 +62,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
-        castedOperands);
+        op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
 
     if (resultType == operands.front().getType()) {
       rewriter.replaceOp(op, {callOp.getResult(0)});

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index d648a293400d4..45222e5a9e39c 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -171,14 +171,13 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
 
   // Create vulkan launch call op.
   auto vulkanLaunchCallOp = builder.create<CallOp>(
-      loc, TypeRange{}, builder.getSymbolRefAttr(kVulkanLaunch),
+      loc, TypeRange{}, SymbolRefAttr::get(builder.getContext(), kVulkanLaunch),
       vulkanLaunchOperands);
 
   // Set SPIR-V binary shader data as an attribute.
   vulkanLaunchCallOp->setAttr(
       kSPIRVBlobAttrName,
-      StringAttr::get(loc->getContext(),
-                      StringRef(binary.data(), binary.size())));
+      builder.getStringAttr(StringRef(binary.data(), binary.size())));
 
   // Set entry point name as an attribute.
   vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 933a158aff640..da3e490e44db5 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -248,9 +248,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
     }
     // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
-        loc, TypeRange(),
-        builder.getSymbolRefAttr(
-            StringRef(symbolName.data(), symbolName.size())),
+        loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
         ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
                    ptrToMemRefDescriptor});
   }
@@ -373,8 +371,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
   // Create call to `initVulkan`.
   auto initVulkanCall = builder.create<LLVM::CallOp>(
-      loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan),
-      ValueRange{});
+      loc, TypeRange{getPointerType()}, kInitVulkan);
   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
   // need to pass that pointer to each Vulkan runtime call.
   auto vulkanRuntime = initVulkanCall.getResult(0);
@@ -396,32 +393,29 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   // Create call to `setBinaryShader` runtime function with the given pointer to
   // SPIR-V binary and binary size.
   builder.create<LLVM::CallOp>(
-      loc, TypeRange(), builder.getSymbolRefAttr(kSetBinaryShader),
+      loc, TypeRange(), kSetBinaryShader,
       ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
   // Create LLVM global with entry point name.
   Value entryPointName = createEntryPointNameConstant(
       spirvAttributes.second.getValue(), loc, builder);
   // Create call to `setEntryPoint` runtime function with the given pointer to
   // entry point name.
-  builder.create<LLVM::CallOp>(loc, TypeRange(),
-                               builder.getSymbolRefAttr(kSetEntryPoint),
+  builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
                                ValueRange{vulkanRuntime, entryPointName});
 
   // Create number of local workgroup for each dimension.
   builder.create<LLVM::CallOp>(
-      loc, TypeRange(), builder.getSymbolRefAttr(kSetNumWorkGroups),
+      loc, TypeRange(), kSetNumWorkGroups,
       ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
                  cInterfaceVulkanLaunchCallOp.getOperand(1),
                  cInterfaceVulkanLaunchCallOp.getOperand(2)});
 
   // Create call to `runOnVulkan` runtime function.
-  builder.create<LLVM::CallOp>(loc, TypeRange(),
-                               builder.getSymbolRefAttr(kRunOnVulkan),
+  builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
                                ValueRange{vulkanRuntime});
 
   // Create call to 'deinitVulkan' runtime function.
-  builder.create<LLVM::CallOp>(loc, TypeRange(),
-                               builder.getSymbolRefAttr(kDeinitVulkan),
+  builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
                                ValueRange{vulkanRuntime});
 
   // Declare runtime functions.

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 4a0dbf17ca2de..6c594901bd667 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -50,7 +50,8 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
   }
 
   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
-  FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
+  FlatSymbolRefAttr fnNameAttr =
+      SymbolRefAttr::get(rewriter.getContext(), fnName);
   auto module = op->getParentOfType<ModuleOp>();
   if (module.lookupSymbol(fnNameAttr.getAttr()))
     return fnNameAttr;

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index e3aaef5cf0cbd..ea3c9943f4b46 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -305,7 +305,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
         op.getLoc(), getVoidPtrType(),
         memref.allocatedPtr(rewriter, op.getLoc()));
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
+        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 0ea5119041ec0..d05586ddc40a3 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -559,9 +559,10 @@ SymbolRefAttr PatternLowering::generateRewriter(
       /*results=*/llvm::None));
 
   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
-  return builder.getSymbolRefAttr(
+  return SymbolRefAttr::get(
+      builder.getContext(),
       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
-      builder.getSymbolRefAttr(rewriterFunc));
+      SymbolRefAttr::get(rewriterFunc));
 }
 
 void PatternLowering::generateRewriter(

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3b474113b51e1..67583f9a74795 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1194,8 +1194,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
   // Helper to emit a call.
   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                        Operation *ref, ValueRange params = ValueRange()) {
-    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
-                                  rewriter.getSymbolRefAttr(ref), params);
+    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
+                                  params);
   }
 };
 

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b8a184fea3c3a..7f2622f5c707b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -542,8 +542,9 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                       blockSize.y, blockSize.z});
   result.addOperands(kernelOperands);
   auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
-  auto kernelSymbol = builder.getSymbolRefAttr(
-      kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())});
+  auto kernelSymbol =
+      SymbolRefAttr::get(kernelModule.getNameAttr(),
+                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
   result.addAttribute(getKernelAttrName(), kernelSymbol);
   SmallVector<int32_t, 8> segmentSizes(8, 1);
   segmentSizes.front() = 0; // Initially no async dependencies.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index a4c8b741a3884..c3f8fcb422402 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -129,7 +129,7 @@ Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
                                                    ValueRange paramTypes,
                                                    ArrayRef<Type> resultTypes) {
   return b
-      .create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn),
+      .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
                             paramTypes)
       ->getResults();
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5f4d83286c712..877f4b8bf52dd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1060,7 +1060,7 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
 
 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
                                spirv::GlobalVariableOp var) {
-  build(builder, state, var.type(), builder.getSymbolRefAttr(var));
+  build(builder, state, var.type(), SymbolRefAttr::get(var));
 }
 
 static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
@@ -1712,8 +1712,7 @@ void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
                                 ArrayRef<Attribute> interfaceVars) {
   build(builder, state,
         spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
-        builder.getSymbolRefAttr(function),
-        builder.getArrayAttr(interfaceVars));
+        SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
 }
 
 static ParseResult parseEntryPointOp(OpAsmParser &parser,
@@ -1772,7 +1771,7 @@ void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
                                    spirv::FuncOp function,
                                    spirv::ExecutionMode executionMode,
                                    ArrayRef<int32_t> params) {
-  build(builder, state, builder.getSymbolRefAttr(function),
+  build(builder, state, SymbolRefAttr::get(function),
         spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
         builder.getI32ArrayAttr(params));
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index fad437ce1330f..fe6433f6777ee 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -68,7 +68,7 @@ class SPIRVAddressOfOpLayoutInfoDecoration
     auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
 
     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
-        op, varOp.type(), rewriter.getSymbolRefAttr(varName.getAttr()));
+        op, varOp.type(), SymbolRefAttr::get(varName.getAttr()));
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
index a3dd9a4be5ec8..7636bc78d6772 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
@@ -156,7 +156,7 @@ struct DecomposeCallGraphTypesForCallOp
         resultMapping.push_back(i);
     }
 
-    CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCallee(),
+    CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
                                                newResultTypes, newOperands);
 
     // Build a replacement value for each result to replace its uses. If a

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 0ced5e55a5183..775d0c40c53c5 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -210,23 +210,6 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
   return ArrayAttr::get(context, value);
 }
 
-FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
-  auto symName =
-      value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
-  assert(symName && "value does not have a valid symbol name");
-  return getSymbolRefAttr(symName.getValue());
-}
-
-FlatSymbolRefAttr Builder::getSymbolRefAttr(StringAttr value) {
-  return SymbolRefAttr::get(value);
-}
-
-SymbolRefAttr
-Builder::getSymbolRefAttr(StringAttr value,
-                          ArrayRef<FlatSymbolRefAttr> nestedReferences) {
-  return SymbolRefAttr::get(value, nestedReferences);
-}
-
 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
   auto attrs = llvm::to_vector<8>(llvm::map_range(
       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e9e1ed8c25452..7faac153510c1 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -10,14 +10,14 @@
 #include "AttributeDetail.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/Twine.h"
 #include "llvm/Support/Endian.h"
 
 using namespace mlir;
@@ -272,14 +272,26 @@ LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // SymbolRefAttr
 //===----------------------------------------------------------------------===//
 
+SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
+                                 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
+  return get(StringAttr::get(ctx, value), nestedRefs);
+}
+
 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
-  return get(StringAttr::get(ctx, value));
+  return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
 }
 
 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
   return get(value, {}).cast<FlatSymbolRefAttr>();
 }
 
+FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
+  auto symName =
+      symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
+  assert(symName && "value does not have a valid symbol name");
+  return SymbolRefAttr::get(symName);
+}
+
 StringAttr SymbolRefAttr::getLeafReference() const {
   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 38f86155d4b71..1e9e87bdb7e5a 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -191,7 +191,8 @@ Attribute Parser::parseAttribute(Type type) {
       consumeToken(Token::at_identifier);
       nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
     }
-    SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs);
+    SymbolRefAttr symbolRefAttr =
+        SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
 
     // If we are populating the assembly state, record this symbol reference.
     if (state.asmState)

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 634d31a8c3965..aed89925ad6e4 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1406,9 +1406,8 @@ class CustomOpAsmParser : public OpAsmParser {
     // If we are populating the assembly parser state, record this as a symbol
     // reference.
     if (parser.getState().asmState) {
-      parser.getState().asmState->addUses(
-          getBuilder().getSymbolRefAttr(result.getValue()),
-          atToken.getLocRange());
+      parser.getState().asmState->addUses(SymbolRefAttr::get(result),
+                                          atToken.getLocRange());
     }
     return success();
   }

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 010594f360cf3..3945cfa6ee0f7 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -245,7 +245,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
   }
   if (auto *f = dyn_cast<llvm::Function>(value))
-    return b.getSymbolRefAttr(f->getName());
+    return SymbolRefAttr::get(b.getContext(), f->getName());
 
   // Convert constant data to a dense elements attribute.
   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
@@ -668,8 +668,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     }
     Operation *op;
     if (llvm::Function *callee = ci->getCalledFunction()) {
-      op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
-                            ops);
+      op = b.create<CallOp>(
+          loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops);
     } else {
       Value calledValue = processValue(ci->getCalledOperand());
       if (!calledValue)
@@ -713,9 +713,10 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
 
     Operation *op;
     if (llvm::Function *callee = ii->getCalledFunction()) {
-      op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
-                              ops, blocks[ii->getNormalDest()], normalArgs,
-                              blocks[ii->getUnwindDest()], unwindArgs);
+      op = b.create<InvokeOp>(
+          loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops,
+          blocks[ii->getNormalDest()], normalArgs, blocks[ii->getUnwindDest()],
+          unwindArgs);
     } else {
       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
@@ -771,7 +772,7 @@ FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
 
   // If it directly has a name, we can use it.
   if (pf->hasName())
-    return b.getSymbolRefAttr(pf->getName());
+    return SymbolRefAttr::get(b.getContext(), pf->getName());
 
   // If it doesn't have a name, currently, only function pointers that are
   // bitcast to i8* are parsed.
@@ -779,7 +780,7 @@ FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
     if (ce->getOpcode() == llvm::Instruction::BitCast &&
         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
-        return b.getSymbolRefAttr(func->getName());
+        return SymbolRefAttr::get(b.getContext(), func->getName());
     }
   }
   return FlatSymbolRefAttr();

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 6137fee34dcd7..c01362d1e2d4c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -44,20 +44,19 @@ Value spirv::Deserializer::getValue(uint32_t id) {
   }
   if (auto varOp = getGlobalVariable(id)) {
     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
-        unknownLoc, varOp.type(),
-        opBuilder.getSymbolRefAttr(varOp.getOperation()));
+        unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
     return addressOfOp.pointer();
   }
   if (auto constOp = getSpecConstant(id)) {
     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
         unknownLoc, constOp.default_value().getType(),
-        opBuilder.getSymbolRefAttr(constOp.getOperation()));
+        SymbolRefAttr::get(constOp.getOperation()));
     return referenceOfOp.reference();
   }
   if (auto constCompositeOp = getSpecConstantComposite(id)) {
     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
         unknownLoc, constCompositeOp.type(),
-        opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
+        SymbolRefAttr::get(constCompositeOp.getOperation()));
     return referenceOfOp.reference();
   }
   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -357,12 +356,12 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
       return emitError(unknownLoc, "undefined result <id> ")
              << words[wordIndex] << " while decoding OpEntryPoint";
     }
-    interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
+    interface.push_back(SymbolRefAttr::get(arg.getOperation()));
     wordIndex++;
   }
-  opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
-                                        opBuilder.getSymbolRefAttr(fnName),
-                                        opBuilder.getArrayAttr(interface));
+  opBuilder.create<spirv::EntryPointOp>(
+      unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
+      opBuilder.getArrayAttr(interface));
   return success();
 }
 
@@ -394,7 +393,8 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
   }
   auto values = opBuilder.getArrayAttr(attrListElems);
   opBuilder.create<spirv::ExecutionModeOp>(
-      unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
+      unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
+      execMode, values);
   return success();
 }
 
@@ -461,8 +461,8 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
   }
 
   auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
-      unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
-      arguments);
+      unknownLoc, resultType,
+      SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
 
   if (resultType)
     valueMap[resultID] = opFunctionCall.getResult(0);

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 132e23283704a..1fbfbf09c066e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -575,7 +575,7 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
              << operands[wordIndex] << "used as initializer";
     }
     wordIndex++;
-    initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
+    initializer = SymbolRefAttr::get(initializerOp.getOperation());
   }
   if (wordIndex != operands.size()) {
     return emitError(unknownLoc,
@@ -1279,7 +1279,7 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   elements.reserve(operands.size() - 2);
   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
     auto elementInfo = getSpecConstant(operands[i]);
-    elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
+    elements.push_back(SymbolRefAttr::get(elementInfo));
   }
 
   auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(

diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index ff0fdc95f45e8..c2b3a956d997e 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -129,10 +129,10 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
 
   // Functions called by this function.
   funcOp.walk([&](CallOp callOp) {
-    StringRef callee = callOp.getCallee();
+    StringAttr callee = callOp.getCalleeAttr();
     for (FuncOp &funcOp : normalizableFuncs) {
       // We compare FuncOp and callee's name.
-      if (callee == funcOp.getName()) {
+      if (callee == funcOp.getNameAttr()) {
         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                             normalizableFuncs);
         break;
@@ -255,10 +255,9 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
     auto callOp = dyn_cast<CallOp>(userOp);
     if (!callOp)
       continue;
-    StringRef callee = callOp.getCallee();
-    Operation *newCallOp = builder.create<CallOp>(
-        userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
-        userOp->getOperands());
+    Operation *newCallOp =
+        builder.create<CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
+                               resultTypes, userOp->getOperands());
     bool replacingMemRefUsesFailed = false;
     bool returnTypeChanged = false;
     for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {


        


More information about the Mlir-commits mailing list