[flang-commits] [flang] 7827753 - Reorder MLIRContext location in BuiltinAttributes.h
Tres Popp via flang-commits
flang-commits at lists.llvm.org
Mon Feb 8 00:28:21 PST 2021
Author: Tres Popp
Date: 2021-02-08T09:28:09+01:00
New Revision: 7827753f9810e846fb702f3e8dcff0bfb37344e1
URL: https://github.com/llvm/llvm-project/commit/7827753f9810e846fb702f3e8dcff0bfb37344e1
DIFF: https://github.com/llvm/llvm-project/commit/7827753f9810e846fb702f3e8dcff0bfb37344e1.diff
LOG: Reorder MLIRContext location in BuiltinAttributes.h
Now the MLIRContext is the first parameter of the attribute `get` builders, rather than the last.
This better matches the rest of the infrastructure and makes
it easier to move these types to being declaratively specified.
Differential Revision: https://reviews.llvm.org/D96111
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Lower/FIRBuilder.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/SymbolInterfaces.td
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/StructsGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8f3670b29d74..cde53725b4a4 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -267,27 +267,27 @@ class fir_AllocatableOp<string mnemonic, list<OpTrait> traits = []> :
static constexpr llvm::StringRef inType() { return "in_type"; }
static constexpr llvm::StringRef lenpName() { return "len_param_count"; }
mlir::Type getAllocatedType();
-
+
bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; }
-
+
unsigned numLenParams() {
if (auto val = (*this)->getAttrOfType<mlir::IntegerAttr>(lenpName()))
return val.getInt();
return 0;
}
-
+
operand_range getLenParams() {
return {operand_begin(), operand_begin() + numLenParams()};
}
-
+
unsigned numShapeOperands() {
return operand_end() - operand_begin() + numLenParams();
}
-
+
operand_range getShapeOperands() {
return {operand_begin() + numLenParams(), operand_end()};
}
-
+
static mlir::Type getRefTy(mlir::Type ty);
/// Get the input type of the allocation
@@ -1131,7 +1131,7 @@ def fir_EmboxCharOp : fir_Op<"emboxchar", [NoSideEffect]> {
}];
let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len);
-
+
let results = (outs fir_BoxCharType);
let assemblyFormat = [{
@@ -1563,7 +1563,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
p.printFunctionalType((*this)->getOperandTypes(),
(*this)->getResultTypes());
}];
-
+
let verifier = [{
auto refTy = ref().getType();
if (fir::isa_ref_type(refTy)) {
@@ -1598,7 +1598,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
-
+
let extraClassDeclaration = [{
static constexpr llvm::StringRef baseType() { return "base_type"; }
mlir::Type getBaseType();
@@ -1686,7 +1686,7 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> {
let printer = [{
p << getOperationName() << ' '
- << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
+ << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
<< ", " << (*this)->getAttr(typeAttrName());
if (getNumOperands()) {
p << '(';
@@ -2007,7 +2007,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
CArg<"ValueRange", "llvm::None">:$iterArgs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
-
+
let extraClassDeclaration = [{
mlir::Block *getBody() { return &region().front(); }
mlir::Value getIterateVar() { return getBody()->getArgument(1); }
@@ -2276,11 +2276,11 @@ def fir_ConstfOp : fir_Op<"constf", [NoSideEffect]> {
}];
let arguments = (ins FirRealAttr:$constant);
-
+
let results = (outs fir_RealType:$res);
let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)";
-
+
let verifier = [{
if (!getType().isa<fir::RealType>())
return emitOpError("must be a !fir.real type");
@@ -2357,7 +2357,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
}];
let results = (outs fir_ComplexType);
-
+
let parser = [{
fir::RealAttr realp;
fir::RealAttr imagp;
@@ -2455,7 +2455,7 @@ def fir_CmpcOp : fir_Op<"cmpc",
def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
let summary = "convert a symbol to an SSA value";
-
+
let description = [{
Convert a symbol (a function or global reference) to an SSA-value to be
used in other Operations.
@@ -2474,7 +2474,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> {
let summary = "encapsulates all Fortran scalar type conversions";
-
+
let description = [{
Generalized type conversion. Convert the ssa value from type T to type U.
Not all pairs of types have conversions. When types T and U are the same
@@ -2705,7 +2705,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
mlir::Type resultType() {
return fir::AllocaOp::wrapResultType(getType());
}
-
+
/// Return the initializer attribute if it exists, or a null attribute.
Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); }
@@ -2728,9 +2728,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
}
mlir::FlatSymbolRefAttr getSymbol() {
- return mlir::FlatSymbolRefAttr::get(
+ return mlir::FlatSymbolRefAttr::get(getContext(),
(*this)->getAttrOfType<mlir::StringAttr>(
- mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext());
+ mlir::SymbolTable::getSymbolAttrName()).getValue());
}
}];
}
@@ -2772,7 +2772,7 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> {
}];
let printer = [{
- p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
+ p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
<< ", " << (*this)->getAttr(intAttrName());
}];
diff --git a/flang/lib/Lower/FIRBuilder.cpp b/flang/lib/Lower/FIRBuilder.cpp
index 3f470d61c286..0a8473b73268 100644
--- a/flang/lib/Lower/FIRBuilder.cpp
+++ b/flang/lib/Lower/FIRBuilder.cpp
@@ -173,7 +173,7 @@ mlir::Value Fortran::lower::FirOpBuilder::createConvert(mlir::Location loc,
fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit(
mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) {
- auto strAttr = mlir::StringAttr::get(data, getContext());
+ auto strAttr = mlir::StringAttr::get(getContext(), data);
auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext());
mlir::NamedAttribute dataAttr(valTag, strAttr);
auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext());
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 3883ce2ed0c8..8523a8371192 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -107,7 +107,7 @@ class PrintOpLowering : public ConversionPattern {
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
- return SymbolRefAttr::get("printf", context);
+ return SymbolRefAttr::get(context, "printf");
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ class PrintOpLowering : public ConversionPattern {
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
- return SymbolRefAttr::get("printf", context);
+ return SymbolRefAttr::get(context, "printf");
}
/// Return a value representing an access into a global string with the given
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 3883ce2ed0c8..8523a8371192 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -107,7 +107,7 @@ class PrintOpLowering : public ConversionPattern {
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
- return SymbolRefAttr::get("printf", context);
+ return SymbolRefAttr::get(context, "printf");
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ class PrintOpLowering : public ConversionPattern {
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
- return SymbolRefAttr::get("printf", context);
+ return SymbolRefAttr::get(context, "printf");
}
/// Return a value representing an access into a global string with the given
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 794417e99652..b903c0928d1b 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -31,7 +31,7 @@ inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
- auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
+ auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
return indexingMaps == maps;
}
@@ -42,7 +42,7 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
- auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
+ auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
return indexingMaps == maps;
}
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 34e7e8cfce12..571c9126f163 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -69,7 +69,7 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
using Base::Base;
using ValueType = ArrayRef<Attribute>;
- static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
+ static ArrayAttr get(MLIRContext *context, ArrayRef<Attribute> value);
ArrayRef<Attribute> getValue() const;
Attribute operator[](unsigned idx) const;
@@ -126,8 +126,8 @@ class DictionaryAttr
/// attributes. This method assumes that the provided list is unordered. If
/// the caller can guarantee that the attributes are ordered by name,
/// getWithSorted should be used instead.
- static DictionaryAttr get(ArrayRef<NamedAttribute> value,
- MLIRContext *context);
+ static DictionaryAttr get(MLIRContext *context,
+ ArrayRef<NamedAttribute> value);
/// Construct a dictionary with an array of values that is known to already be
/// sorted by name and uniqued.
@@ -250,7 +250,7 @@ class BoolAttr : public Attribute {
using Attribute::Attribute;
using ValueType = bool;
- static BoolAttr get(bool value, MLIRContext *context);
+ static BoolAttr get(MLIRContext *context, bool value);
/// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
/// avoid bringing in all of IntegerAttrs methods.
@@ -292,8 +292,8 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
using Base::Base;
/// Get or create a new OpaqueAttr with the provided dialect and string data.
- static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
- MLIRContext *context);
+ static OpaqueAttr get(MLIRContext *context, Identifier dialect,
+ StringRef attrData, Type type);
/// Get or create a new OpaqueAttr with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
@@ -325,7 +325,7 @@ class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
using ValueType = StringRef;
/// Get an instance of a StringAttr with the given string.
- static StringAttr get(StringRef bytes, MLIRContext *context);
+ static StringAttr get(MLIRContext *context, StringRef bytes);
/// Get an instance of a StringAttr with the given string and Type.
static StringAttr get(StringRef bytes, Type type);
@@ -348,13 +348,12 @@ class SymbolRefAttr
using Base::Base;
/// Construct a symbol reference for the given value name.
- static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
+ static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
/// Construct a symbol reference for the given value name, and a set of nested
/// references that are further resolve to a nested symbol.
- static SymbolRefAttr get(StringRef value,
- ArrayRef<FlatSymbolRefAttr> references,
- MLIRContext *ctx);
+ static SymbolRefAttr get(MLIRContext *ctx, StringRef value,
+ ArrayRef<FlatSymbolRefAttr> references);
/// Returns the name of the top level symbol reference, i.e. the root of the
/// reference path.
@@ -377,8 +376,8 @@ class FlatSymbolRefAttr : public SymbolRefAttr {
using ValueType = StringRef;
/// Construct a symbol reference for the given value name.
- static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
- return SymbolRefAttr::get(value, ctx);
+ static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) {
+ return SymbolRefAttr::get(ctx, value);
}
/// Returns the name of the held symbol reference.
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index be8a68979203..588a5f7ed62c 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -569,7 +569,7 @@ void FunctionLike<ConcreteType>::setArgAttrs(
if (attributes.empty())
return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
Operation *op = this->getOperation();
- op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
+ op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
}
template <typename ConcreteType>
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 45b9c490fd21..70cd55dbbb13 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -315,7 +315,7 @@ class alignas(8) Operation final
attrs = newAttrs;
}
void setAttrs(ArrayRef<NamedAttribute> newAttrs) {
- setAttrs(DictionaryAttr::get(newAttrs, getContext()));
+ setAttrs(DictionaryAttr::get(getContext(), newAttrs));
}
/// Return the specified attribute if present, null otherwise.
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index c5f252e45a20..a7b1fd8cfe64 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -44,7 +44,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
/*defaultImplementation=*/[{
this->getOperation()->setAttr(
mlir::SymbolTable::getSymbolAttrName(),
- StringAttr::get(name, this->getOperation()->getContext()));
+ StringAttr::get(this->getOperation()->getContext(), name));
}]
>,
InterfaceMethod<"Gets the visibility of this symbol.",
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 90ed9cb0ad02..9e61e3a9d6e0 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -42,9 +42,9 @@ bool mlirAttributeIsAArray(MlirAttribute attr) {
MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
MlirAttribute const *elements) {
SmallVector<Attribute, 8> attrs;
- return wrap(ArrayAttr::get(
- unwrapList(static_cast<size_t>(numElements), elements, attrs),
- unwrap(ctx)));
+ return wrap(
+ ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
+ elements, attrs)));
}
intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
@@ -71,7 +71,7 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
attributes.emplace_back(
Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
unwrap(elements[i].attribute));
- return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
+ return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
}
intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
@@ -137,7 +137,7 @@ bool mlirAttributeIsABool(MlirAttribute attr) {
}
MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
- return wrap(BoolAttr::get(value, unwrap(ctx)));
+ return wrap(BoolAttr::get(unwrap(ctx), value));
}
bool mlirBoolAttrGetValue(MlirAttribute attr) {
@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
intptr_t dataLength, const char *data,
MlirType type) {
- return wrap(
- OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
- StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
+ return wrap(OpaqueAttr::get(
+ unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
+ StringRef(data, dataLength), unwrap(type)));
}
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
@@ -185,7 +185,7 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
- return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
+ return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
}
MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
@@ -211,7 +211,7 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
refs.reserve(numReferences);
for (intptr_t i = 0; i < numReferences; ++i)
refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
- return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
+ return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
}
MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
@@ -241,7 +241,7 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
}
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
- return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
+ return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
}
MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
index 447b00567776..1b9e36180114 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
@@ -148,7 +148,7 @@ StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation(
auto blob = convertModuleToBlob(llvmModule, loc, name);
if (!blob)
return {};
- return StringAttr::get({blob->data(), blob->size()}, loc->getContext());
+ return StringAttr::get(loc->getContext(), {blob->data(), blob->size()});
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 887d3e798af7..5b62ca455dea 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -177,12 +177,12 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
// Set SPIR-V binary shader data as an attribute.
vulkanLaunchCallOp->setAttr(
kSPIRVBlobAttrName,
- StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
+ StringAttr::get(loc->getContext(), {binary.data(), binary.size()}));
// Set entry point name as an attribute.
vulkanLaunchCallOp->setAttr(
kSPIRVEntryPointAttrName,
- StringAttr::get(launchOp.getKernelName(), loc->getContext()));
+ StringAttr::get(loc->getContext(), launchOp.getKernelName()));
launchOp.erase();
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 87026e4483e6..29cf42205a56 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -687,8 +687,8 @@ class ExecutionModePattern
rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, executionMode,
- ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
- context));
+ ArrayAttr::get(context,
+ {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
// Insert extra operands if they exist into execution mode info struct.
for (unsigned i = 0, e = values.size(); i < e; ++i) {
@@ -696,9 +696,9 @@ class ExecutionModePattern
Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, entry,
- ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
- rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
- context));
+ ArrayAttr::get(context,
+ {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
}
rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
rewriter.eraseOp(op);
@@ -1297,17 +1297,17 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
switch (funcOp.function_control()) {
#define DISPATCH(functionControl, llvmAttr) \
case functionControl: \
- newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \
+ newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
break;
DISPATCH(spirv::FunctionControl::Inline,
- StringAttr::get("alwaysinline", context));
+ StringAttr::get(context, "alwaysinline"));
DISPATCH(spirv::FunctionControl::DontInline,
- StringAttr::get("noinline", context));
+ StringAttr::get(context, "noinline"));
DISPATCH(spirv::FunctionControl::Pure,
- StringAttr::get("readonly", context));
+ StringAttr::get(context, "readonly"));
DISPATCH(spirv::FunctionControl::Const,
- StringAttr::get("readnone", context));
+ StringAttr::get(context, "readnone"));
#undef DISPATCH
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 794f4a5d6c1e..ea0a4259637c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -4016,7 +4016,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
- StringAttr::get(this->dataLayout, m.getContext()));
+ StringAttr::get(m.getContext(), this->dataLayout));
}
};
} // end namespace
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9e88250e2cab..683de815a54e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -762,7 +762,7 @@ class VectorExtractOpConversion
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
auto nMinusOnePositionAttrs =
- ArrayAttr::get(positionAttrs.drop_back(), context);
+ ArrayAttr::get(context, positionAttrs.drop_back());
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
@@ -871,7 +871,7 @@ class VectorInsertOpConversion
if (positionAttrs.size() > 1) {
oneDVectorType = reducedVectorTypeBack(destVectorType);
auto nMinusOnePositionAttrs =
- ArrayAttr::get(positionAttrs.drop_back(), context);
+ ArrayAttr::get(context, positionAttrs.drop_back());
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
@@ -887,7 +887,7 @@ class VectorInsertOpConversion
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
auto nMinusOnePositionAttrs =
- ArrayAttr::get(positionAttrs.drop_back(), context);
+ ArrayAttr::get(context, positionAttrs.drop_back());
inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
adaptor.dest(), inserted,
nMinusOnePositionAttrs);
diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
index c1d0820e1cc7..6ccb59aff35a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
@@ -53,7 +53,7 @@ LogicalResult setMappingAttr(scf::ParallelOp ploopOp,
}
ArrayRef<Attribute> mappingAsAttrs(mapping.data(), mapping.size());
ploopOp->setAttr(getMappingAttrName(),
- ArrayAttr::get(mappingAsAttrs, ploopOp.getContext()));
+ ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs));
return success();
}
} // namespace gpu
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a3960ae94b27..e96668779401 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -225,7 +225,7 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
if (genericAttrNamesSet.count(attr.first.strref()) > 0)
genericAttrs.push_back(attr);
if (!genericAttrs.empty()) {
- auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
+ auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
p << genericDictAttr;
}
@@ -833,7 +833,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
// Handle the corner case of the result being a rank 0 shaped type. Return an
// emtpy ArrayAttr.
if (mapsConsumer.empty() && !mapsProducer.empty())
- return ArrayAttr::get(ArrayRef<Attribute>(), context);
+ return ArrayAttr::get(context, ArrayRef<Attribute>());
if (mapsProducer.empty() || mapsConsumer.empty() ||
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -854,7 +854,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
numLhsDims, /*numSymbols =*/0, reassociations, context)));
reassociations.clear();
}
- return ArrayAttr::get(reassociationMaps, context);
+ return ArrayAttr::get(context, reassociationMaps);
}
namespace {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 8db4824cbbd2..c7b76404b2f8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -137,11 +137,11 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
// wrong, so abort.
if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
return nullptr;
- return ArrayAttr::get(
- llvm::to_vector<4>(llvm::map_range(
- newIndexingMaps,
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
- context);
+ return ArrayAttr::get(context,
+ llvm::to_vector<4>(llvm::map_range(
+ newIndexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ })));
}
/// Modify the region of indexed generic op to drop arguments corresponding to
@@ -220,7 +220,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
rewriter.startRootUpdate(op);
op.indexing_mapsAttr(newIndexingMapAttr);
- op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+ op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
(void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
rewriter.finalizeRootUpdate(op);
return success();
@@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
RankedTensorType::get(newShape, type.getElementType()),
AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
newIndexExprs, context),
- ArrayAttr::get(reassociationMaps, context)};
+ ArrayAttr::get(context, reassociationMaps)};
return info;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index cac0ae0d081c..b893f2ba6721 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -77,9 +77,9 @@ LinalgOp mlir::linalg::interchange(LinalgOp op,
applyPermutationToVector(itTypesVector, interchangeVector);
op->setAttr(getIndexingMapsAttrName(),
- ArrayAttr::get(newIndexingMaps, context));
+ ArrayAttr::get(context, newIndexingMaps));
op->setAttr(getIteratorTypesAttrName(),
- ArrayAttr::get(itTypesVector, context));
+ ArrayAttr::get(context, itTypesVector));
return op;
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 9b62b4289c77..4ce29b4a8397 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
});
for (auto &var : interfaceVarSet) {
interfaceVars.push_back(SymbolRefAttr::get(
- cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+ funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
}
return success();
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0902b297ddd3..65ebc54aeeb3 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -338,7 +338,7 @@ OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
return a;
}
// If this is reached, all inputs were statically known passing.
- return BoolAttr::get(true, getContext());
+ return BoolAttr::get(getContext(), true);
}
static LogicalResult verify(AssumingAllOp op) {
@@ -482,10 +482,10 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// Both operands are not needed if one is a scalar.
if (operands[0] &&
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
- return BoolAttr::get(true, getContext());
+ return BoolAttr::get(getContext(), true);
if (operands[1] &&
operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
- return BoolAttr::get(true, getContext());
+ return BoolAttr::get(getContext(), true);
if (operands[0] && operands[1]) {
auto lhsShape = llvm::to_vector<6>(
@@ -494,7 +494,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
- return BoolAttr::get(true, getContext());
+ return BoolAttr::get(getContext(), true);
}
// Lastly, see if folding can be completed based on what constraints are known
@@ -506,7 +506,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
- return BoolAttr::get(true, getContext());
+ return BoolAttr::get(getContext(), true);
// Because a failing witness result here represents an eventual assertion
// failure, we do not replace it with a constant witness.
@@ -526,7 +526,7 @@ void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
if (llvm::all_of(operands,
[&](Attribute a) { return a && a == operands[0]; }))
- return BoolAttr::get(true, getContext());
+ return BoolAttr::get(getContext(), true);
// Because a failing witness result here represents an eventual assertion
// failure, we do not try to replace it with a constant witness. Similarly, we
@@ -573,14 +573,14 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs())
- return BoolAttr::get(true, getContext());
+ return BoolAttr::get(getContext(), true);
auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (lhs == nullptr)
return {};
auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
if (rhs == nullptr)
return {};
- return BoolAttr::get(lhs == rhs, getContext());
+ return BoolAttr::get(getContext(), lhs == rhs);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index c085c1cd33a7..ca2e2731df03 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -844,7 +844,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
- return BoolAttr::get(val, getContext());
+ return BoolAttr::get(getContext(), val);
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
@@ -853,7 +853,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return BoolAttr::get(val, getContext());
+ return BoolAttr::get(getContext(), val);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index f20b713e8e77..9fe8cf23c162 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -247,7 +247,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
- auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
+ auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
p << op.rhs() << ", " << op.acc();
if (op.masks().size() == 2)
@@ -1445,7 +1445,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
});
- return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
+ return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
static LogicalResult verify(InsertStridedSliceOp op) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8a5206eb0b1c..bafeccbd53ea 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -92,11 +92,11 @@ NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
BoolAttr Builder::getBoolAttr(bool value) {
- return BoolAttr::get(value, context);
+ return BoolAttr::get(context, value);
}
DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
- return DictionaryAttr::get(value, context);
+ return DictionaryAttr::get(context, value);
}
IntegerAttr Builder::getIndexAttr(int64_t value) {
@@ -200,11 +200,11 @@ FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
}
StringAttr Builder::getStringAttr(StringRef bytes) {
- return StringAttr::get(bytes, context);
+ return StringAttr::get(context, bytes);
}
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
- return ArrayAttr::get(value, context);
+ return ArrayAttr::get(context, value);
}
FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
@@ -214,12 +214,12 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
return getSymbolRefAttr(symName.getValue());
}
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
- return SymbolRefAttr::get(value, getContext());
+ return SymbolRefAttr::get(getContext(), value);
}
SymbolRefAttr
Builder::getSymbolRefAttr(StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
- return SymbolRefAttr::get(value, nestedReferences, getContext());
+ return SymbolRefAttr::get(getContext(), value, nestedReferences);
}
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 162bed96e3f4..58a5b3370364 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -35,7 +35,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
// ArrayAttr
//===----------------------------------------------------------------------===//
-ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) {
return Base::get(context, value);
}
@@ -134,8 +134,8 @@ DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
return findDuplicateElement(array);
}
-DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
- MLIRContext *context) {
+DictionaryAttr DictionaryAttr::get(MLIRContext *context,
+ ArrayRef<NamedAttribute> value) {
if (value.empty())
return DictionaryAttr::getEmpty(context);
assert(llvm::all_of(value,
@@ -267,13 +267,12 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
// SymbolRefAttr
//===----------------------------------------------------------------------===//
-FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
}
-SymbolRefAttr SymbolRefAttr::get(StringRef value,
- ArrayRef<FlatSymbolRefAttr> nestedReferences,
- MLIRContext *ctx) {
+SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences) {
return Base::get(ctx, value, nestedReferences);
}
@@ -294,7 +293,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
if (type.isSignlessInteger(1))
- return BoolAttr::get(value.getBoolValue(), type.getContext());
+ return BoolAttr::get(type.getContext(), value.getBoolValue());
return Base::get(type.getContext(), type, value);
}
@@ -377,8 +376,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
// OpaqueAttr
//===----------------------------------------------------------------------===//
-OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
- MLIRContext *context) {
+OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
+ StringRef attrData, Type type) {
return Base::get(context, dialect, attrData, type);
}
@@ -409,7 +408,7 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
// StringAttr
//===----------------------------------------------------------------------===//
-StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) {
return get(bytes, NoneType::get(context));
}
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 469aa310140c..db383c691c7c 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -166,7 +166,7 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
newAttrs.insert(attr);
for (auto &attr : getAttrs())
newAttrs.insert(attr);
- dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext()));
+ dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
// Clone the body.
getBody().cloneInto(&dest.getBody(), mapper);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index dbfa1bdf6f7e..8d13a9c4af32 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -872,7 +872,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
storage->setType(NoneType::get(ctx));
}
-BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
+BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
}
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b4fe9f854dda..be312689cebb 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -76,7 +76,7 @@ Operation *Operation::create(Location location, OperationName name,
ArrayRef<NamedAttribute> attributes,
BlockRange successors, unsigned numRegions) {
return create(location, name, resultTypes, operands,
- DictionaryAttr::get(attributes, location.getContext()),
+ DictionaryAttr::get(location.getContext(), attributes),
successors, numRegions);
}
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index b198600e9242..70133d22482f 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -46,7 +46,7 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
MLIRContext *ctx = symbol->getContext();
- auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
+ auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName);
results.push_back(leafRef);
// Early exit for when 'within' is the parent of 'symbol'.
@@ -67,13 +67,13 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
getNameIfSymbol(symbolTableOp, symbolNameId);
if (!symbolTableName)
return failure();
- results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
+ results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs));
symbolTableOp = symbolTableOp->getParentOp();
if (symbolTableOp == within)
break;
nestedRefs.insert(nestedRefs.begin(),
- FlatSymbolRefAttr::get(*symbolTableName, ctx));
+ FlatSymbolRefAttr::get(ctx, *symbolTableName));
} while (true);
return success();
}
@@ -203,7 +203,7 @@ StringRef SymbolTable::getSymbolName(Operation *symbol) {
/// Sets the name of the given symbol operation.
void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
symbol->setAttr(getSymbolAttrName(),
- StringAttr::get(name, symbol->getContext()));
+ StringAttr::get(symbol->getContext(), name));
}
/// Returns the visibility of the given symbol operation.
@@ -235,7 +235,7 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
"unknown symbol visibility kind");
StringRef visName = vis == Visibility::Private ? "private" : "nested";
- symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
+ symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
}
/// Returns the nearest symbol table from a given operation `from`. Returns
@@ -603,7 +603,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
// doesn't support parent references.
if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
symbol->getParentOp())
- return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
+ return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}};
return {};
}
@@ -659,7 +659,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
template <typename IRUnit>
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
IRUnit *limit) {
- return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
+ return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}};
}
/// Returns true if the given reference 'SubRef' is a sub reference of the
@@ -825,11 +825,11 @@ static Attribute rebuildAttrAfterRAUW(
if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
updateAttrs(make_second_range(newAttrs));
- return DictionaryAttr::get(newAttrs, dictAttr.getContext());
+ return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
}
auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
updateAttrs(newAttrs);
- return ArrayAttr::get(newAttrs, container.getContext());
+ return ArrayAttr::get(container.getContext(), newAttrs);
}
/// Generates a new symbol reference attribute with a new leaf reference.
@@ -839,8 +839,8 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
return newLeafAttr;
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
nestedRefs.back() = newLeafAttr;
- return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
- oldAttr.getContext());
+ return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(),
+ nestedRefs);
}
/// The implementation of SymbolTable::replaceAllSymbolUses below.
@@ -867,7 +867,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
// Generate a new attribute to replace the given attribute.
MLIRContext *ctx = limit->getContext();
- FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
+ FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol);
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
@@ -883,13 +883,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
if (useRef != scope.symbol) {
if (scope.symbol.isa<FlatSymbolRefAttr>()) {
replacementRef =
- SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
+ SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences());
} else {
auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
newLeafAttr;
replacementRef =
- SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
+ SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs);
}
}
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 859e8e279917..98f74174e5a3 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -148,7 +148,7 @@ Attribute Parser::parseAttribute(Type type) {
return Attribute();
return type ? StringAttr::get(val, type)
- : StringAttr::get(val, getContext());
+ : StringAttr::get(getContext(), val);
}
// Parse a symbol reference attribute.
@@ -176,7 +176,7 @@ Attribute Parser::parseAttribute(Type type) {
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
- nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+ nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
}
return builder.getSymbolRefAttr(nameStr, nestedRefs);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2f0b3379d152..52ce37eb79ab 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -742,7 +742,8 @@ void OpEmitter::genAttrGetters() {
body << " ::mlir::MLIRContext* ctx = getContext();\n";
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
- body << " return ::mlir::DictionaryAttr::get({\n";
+ body << " return ::mlir::DictionaryAttr::get(";
+ body << " ctx, {\n";
interleave(
derivedAttrs, body,
[&](const NamedAttribute &namedAttr) {
@@ -755,7 +756,7 @@ void OpEmitter::genAttrGetters() {
<< "}";
},
",\n");
- body << "\n }, ctx);";
+ body << "});";
}
}
}
diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
index 5595986e4016..52f522387017 100644
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -150,7 +150,7 @@ static void emitFactoryDef(llvm::StringRef structName,
}
const char *getEndInfo = R"(
- ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context);
+ ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields);
return dict.dyn_cast<{0}>();
}
)";
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index 0dd9ef9de3e6..ef0bdd81ee3a 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -67,7 +67,7 @@ TEST(StructsGenTest, ClassofExtraFalse) {
newValues.push_back(wrongAttr);
// Make a new DictionaryAttr and validate.
- auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+ auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@@ -88,7 +88,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) {
auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
newValues.push_back(wrongAttr);
- auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+ auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@@ -113,7 +113,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
newValues.push_back(wrongAttr);
- auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+ auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@@ -130,7 +130,7 @@ TEST(StructsGenTest, ClassofMissingFalse) {
expectedValues.begin() + 1, expectedValues.end());
// Make a new DictionaryAttr and validate it is not a validate TestStruct.
- auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+ auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
More information about the flang-commits mailing list.