[Mlir-commits] [mlir] 511dd4f - Revert "Reorder MLIRContext location in BuiltinAttributes.h"
Tres Popp
llvmlistbot at llvm.org
Mon Feb 8 00:33:34 PST 2021
Author: Tres Popp
Date: 2021-02-08T09:32:42+01:00
New Revision: 511dd4f4383b1c2873beac4dbea2df302f1f9d0c
URL: https://github.com/llvm/llvm-project/commit/511dd4f4383b1c2873beac4dbea2df302f1f9d0c
DIFF: https://github.com/llvm/llvm-project/commit/511dd4f4383b1c2873beac4dbea2df302f1f9d0c.diff
LOG: Revert "Reorder MLIRContext location in BuiltinAttributes.h"
This reverts commit 7827753f9810e846fb702f3e8dcff0bfb37344e1.
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Lower/FIRBuilder.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/SymbolInterfaces.td
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/StructsGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index cde53725b4a4..8f3670b29d74 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -267,27 +267,27 @@ class fir_AllocatableOp<string mnemonic, list<OpTrait> traits = []> :
static constexpr llvm::StringRef inType() { return "in_type"; }
static constexpr llvm::StringRef lenpName() { return "len_param_count"; }
mlir::Type getAllocatedType();
-
+
bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; }
-
+
unsigned numLenParams() {
if (auto val = (*this)->getAttrOfType<mlir::IntegerAttr>(lenpName()))
return val.getInt();
return 0;
}
-
+
operand_range getLenParams() {
return {operand_begin(), operand_begin() + numLenParams()};
}
-
+
unsigned numShapeOperands() {
return operand_end() - operand_begin() + numLenParams();
}
-
+
operand_range getShapeOperands() {
return {operand_begin() + numLenParams(), operand_end()};
}
-
+
static mlir::Type getRefTy(mlir::Type ty);
/// Get the input type of the allocation
@@ -1131,7 +1131,7 @@ def fir_EmboxCharOp : fir_Op<"emboxchar", [NoSideEffect]> {
}];
let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len);
-
+
let results = (outs fir_BoxCharType);
let assemblyFormat = [{
@@ -1563,7 +1563,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
p.printFunctionalType((*this)->getOperandTypes(),
(*this)->getResultTypes());
}];
-
+
let verifier = [{
auto refTy = ref().getType();
if (fir::isa_ref_type(refTy)) {
@@ -1598,7 +1598,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
-
+
let extraClassDeclaration = [{
static constexpr llvm::StringRef baseType() { return "base_type"; }
mlir::Type getBaseType();
@@ -1686,7 +1686,7 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> {
let printer = [{
p << getOperationName() << ' '
- << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
+ << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
<< ", " << (*this)->getAttr(typeAttrName());
if (getNumOperands()) {
p << '(';
@@ -2007,7 +2007,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
CArg<"ValueRange", "llvm::None">:$iterArgs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
-
+
let extraClassDeclaration = [{
mlir::Block *getBody() { return &region().front(); }
mlir::Value getIterateVar() { return getBody()->getArgument(1); }
@@ -2276,11 +2276,11 @@ def fir_ConstfOp : fir_Op<"constf", [NoSideEffect]> {
}];
let arguments = (ins FirRealAttr:$constant);
-
+
let results = (outs fir_RealType:$res);
let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)";
-
+
let verifier = [{
if (!getType().isa<fir::RealType>())
return emitOpError("must be a !fir.real type");
@@ -2357,7 +2357,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
}];
let results = (outs fir_ComplexType);
-
+
let parser = [{
fir::RealAttr realp;
fir::RealAttr imagp;
@@ -2455,7 +2455,7 @@ def fir_CmpcOp : fir_Op<"cmpc",
def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
let summary = "convert a symbol to an SSA value";
-
+
let description = [{
Convert a symbol (a function or global reference) to an SSA-value to be
used in other Operations.
@@ -2474,7 +2474,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> {
let summary = "encapsulates all Fortran scalar type conversions";
-
+
let description = [{
Generalized type conversion. Convert the ssa value from type T to type U.
Not all pairs of types have conversions. When types T and U are the same
@@ -2705,7 +2705,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
mlir::Type resultType() {
return fir::AllocaOp::wrapResultType(getType());
}
-
+
/// Return the initializer attribute if it exists, or a null attribute.
Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); }
@@ -2728,9 +2728,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
}
mlir::FlatSymbolRefAttr getSymbol() {
- return mlir::FlatSymbolRefAttr::get(getContext(),
+ return mlir::FlatSymbolRefAttr::get(
(*this)->getAttrOfType<mlir::StringAttr>(
- mlir::SymbolTable::getSymbolAttrName()).getValue());
+ mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext());
}
}];
}
@@ -2772,7 +2772,7 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> {
}];
let printer = [{
- p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
+ p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
<< ", " << (*this)->getAttr(intAttrName());
}];
diff --git a/flang/lib/Lower/FIRBuilder.cpp b/flang/lib/Lower/FIRBuilder.cpp
index 0a8473b73268..3f470d61c286 100644
--- a/flang/lib/Lower/FIRBuilder.cpp
+++ b/flang/lib/Lower/FIRBuilder.cpp
@@ -173,7 +173,7 @@ mlir::Value Fortran::lower::FirOpBuilder::createConvert(mlir::Location loc,
fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit(
mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) {
- auto strAttr = mlir::StringAttr::get(getContext(), data);
+ auto strAttr = mlir::StringAttr::get(data, getContext());
auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext());
mlir::NamedAttribute dataAttr(valTag, strAttr);
auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext());
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 8523a8371192..3883ce2ed0c8 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -107,7 +107,7 @@ class PrintOpLowering : public ConversionPattern {
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
- return SymbolRefAttr::get(context, "printf");
+ return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ class PrintOpLowering : public ConversionPattern {
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
- return SymbolRefAttr::get(context, "printf");
+ return SymbolRefAttr::get("printf", context);
}
/// Return a value representing an access into a global string with the given
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 8523a8371192..3883ce2ed0c8 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -107,7 +107,7 @@ class PrintOpLowering : public ConversionPattern {
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
- return SymbolRefAttr::get(context, "printf");
+ return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ class PrintOpLowering : public ConversionPattern {
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
- return SymbolRefAttr::get(context, "printf");
+ return SymbolRefAttr::get("printf", context);
}
/// Return a value representing an access into a global string with the given
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index b903c0928d1b..794417e99652 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -31,7 +31,7 @@ inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
- auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+ auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
return indexingMaps == maps;
}
@@ -42,7 +42,7 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
- auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+ auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
return indexingMaps == maps;
}
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 571c9126f163..34e7e8cfce12 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -69,7 +69,7 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
using Base::Base;
using ValueType = ArrayRef<Attribute>;
- static ArrayAttr get(MLIRContext *context, ArrayRef<Attribute> value);
+ static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
ArrayRef<Attribute> getValue() const;
Attribute operator[](unsigned idx) const;
@@ -126,8 +126,8 @@ class DictionaryAttr
/// attributes. This method assumes that the provided list is unordered. If
/// the caller can guarantee that the attributes are ordered by name,
/// getWithSorted should be used instead.
- static DictionaryAttr get(MLIRContext *context,
- ArrayRef<NamedAttribute> value);
+ static DictionaryAttr get(ArrayRef<NamedAttribute> value,
+ MLIRContext *context);
/// Construct a dictionary with an array of values that is known to already be
/// sorted by name and uniqued.
@@ -250,7 +250,7 @@ class BoolAttr : public Attribute {
using Attribute::Attribute;
using ValueType = bool;
- static BoolAttr get(MLIRContext *context, bool value);
+ static BoolAttr get(bool value, MLIRContext *context);
/// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
/// avoid bringing in all of IntegerAttrs methods.
@@ -292,8 +292,8 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
using Base::Base;
/// Get or create a new OpaqueAttr with the provided dialect and string data.
- static OpaqueAttr get(MLIRContext *context, Identifier dialect,
- StringRef attrData, Type type);
+ static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
+ MLIRContext *context);
/// Get or create a new OpaqueAttr with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
@@ -325,7 +325,7 @@ class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
using ValueType = StringRef;
/// Get an instance of a StringAttr with the given string.
- static StringAttr get(MLIRContext *context, StringRef bytes);
+ static StringAttr get(StringRef bytes, MLIRContext *context);
/// Get an instance of a StringAttr with the given string and Type.
static StringAttr get(StringRef bytes, Type type);
@@ -348,12 +348,13 @@ class SymbolRefAttr
using Base::Base;
/// Construct a symbol reference for the given value name.
- static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
+ static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
/// Construct a symbol reference for the given value name, and a set of nested
/// references that are further resolve to a nested symbol.
- static SymbolRefAttr get(MLIRContext *ctx, StringRef value,
- ArrayRef<FlatSymbolRefAttr> references);
+ static SymbolRefAttr get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> references,
+ MLIRContext *ctx);
/// Returns the name of the top level symbol reference, i.e. the root of the
/// reference path.
@@ -376,8 +377,8 @@ class FlatSymbolRefAttr : public SymbolRefAttr {
using ValueType = StringRef;
/// Construct a symbol reference for the given value name.
- static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) {
- return SymbolRefAttr::get(ctx, value);
+ static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
+ return SymbolRefAttr::get(value, ctx);
}
/// Returns the name of the held symbol reference.
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 588a5f7ed62c..be8a68979203 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -569,7 +569,7 @@ void FunctionLike<ConcreteType>::setArgAttrs(
if (attributes.empty())
return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
Operation *op = this->getOperation();
- op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
+ op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
}
template <typename ConcreteType>
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 70cd55dbbb13..45b9c490fd21 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -315,7 +315,7 @@ class alignas(8) Operation final
attrs = newAttrs;
}
void setAttrs(ArrayRef<NamedAttribute> newAttrs) {
- setAttrs(DictionaryAttr::get(getContext(), newAttrs));
+ setAttrs(DictionaryAttr::get(newAttrs, getContext()));
}
/// Return the specified attribute if present, null otherwise.
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index a7b1fd8cfe64..c5f252e45a20 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -44,7 +44,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
/*defaultImplementation=*/[{
this->getOperation()->setAttr(
mlir::SymbolTable::getSymbolAttrName(),
- StringAttr::get(this->getOperation()->getContext(), name));
+ StringAttr::get(name, this->getOperation()->getContext()));
}]
>,
InterfaceMethod<"Gets the visibility of this symbol.",
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 9e61e3a9d6e0..90ed9cb0ad02 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -42,9 +42,9 @@ bool mlirAttributeIsAArray(MlirAttribute attr) {
MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
MlirAttribute const *elements) {
SmallVector<Attribute, 8> attrs;
- return wrap(
- ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
- elements, attrs)));
+ return wrap(ArrayAttr::get(
+ unwrapList(static_cast<size_t>(numElements), elements, attrs),
+ unwrap(ctx)));
}
intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
@@ -71,7 +71,7 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
attributes.emplace_back(
Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
unwrap(elements[i].attribute));
- return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
+ return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
}
intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
@@ -137,7 +137,7 @@ bool mlirAttributeIsABool(MlirAttribute attr) {
}
MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
- return wrap(BoolAttr::get(unwrap(ctx), value));
+ return wrap(BoolAttr::get(value, unwrap(ctx)));
}
bool mlirBoolAttrGetValue(MlirAttribute attr) {
@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
intptr_t dataLength, const char *data,
MlirType type) {
- return wrap(OpaqueAttr::get(
- unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
- StringRef(data, dataLength), unwrap(type)));
+ return wrap(
+ OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
+ StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
}
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
@@ -185,7 +185,7 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
- return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
+ return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
}
MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
@@ -211,7 +211,7 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
refs.reserve(numReferences);
for (intptr_t i = 0; i < numReferences; ++i)
refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
- return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
+ return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
}
MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
@@ -241,7 +241,7 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
}
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
- return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
+ return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
}
MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
index 1b9e36180114..447b00567776 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
@@ -148,7 +148,7 @@ StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation(
auto blob = convertModuleToBlob(llvmModule, loc, name);
if (!blob)
return {};
- return StringAttr::get(loc->getContext(), {blob->data(), blob->size()});
+ return StringAttr::get({blob->data(), blob->size()}, loc->getContext());
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 5b62ca455dea..887d3e798af7 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -177,12 +177,12 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
// Set SPIR-V binary shader data as an attribute.
vulkanLaunchCallOp->setAttr(
kSPIRVBlobAttrName,
- StringAttr::get(loc->getContext(), {binary.data(), binary.size()}));
+ StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
// Set entry point name as an attribute.
vulkanLaunchCallOp->setAttr(
kSPIRVEntryPointAttrName,
- StringAttr::get(loc->getContext(), launchOp.getKernelName()));
+ StringAttr::get(launchOp.getKernelName(), loc->getContext()));
launchOp.erase();
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 29cf42205a56..87026e4483e6 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -687,8 +687,8 @@ class ExecutionModePattern
rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, executionMode,
- ArrayAttr::get(context,
- {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
+ ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
+ context));
// Insert extra operands if they exist into execution mode info struct.
for (unsigned i = 0, e = values.size(); i < e; ++i) {
@@ -696,9 +696,9 @@ class ExecutionModePattern
Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, entry,
- ArrayAttr::get(context,
- {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
- rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
+ ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
+ context));
}
rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
rewriter.eraseOp(op);
@@ -1297,17 +1297,17 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
switch (funcOp.function_control()) {
#define DISPATCH(functionControl, llvmAttr) \
case functionControl: \
- newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
+ newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \
break;
DISPATCH(spirv::FunctionControl::Inline,
- StringAttr::get(context, "alwaysinline"));
+ StringAttr::get("alwaysinline", context));
DISPATCH(spirv::FunctionControl::DontInline,
- StringAttr::get(context, "noinline"));
+ StringAttr::get("noinline", context));
DISPATCH(spirv::FunctionControl::Pure,
- StringAttr::get(context, "readonly"));
+ StringAttr::get("readonly", context));
DISPATCH(spirv::FunctionControl::Const,
- StringAttr::get(context, "readnone"));
+ StringAttr::get("readnone", context));
#undef DISPATCH
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index ea0a4259637c..794f4a5d6c1e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -4016,7 +4016,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
- StringAttr::get(m.getContext(), this->dataLayout));
+ StringAttr::get(this->dataLayout, m.getContext()));
}
};
} // end namespace
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 683de815a54e..9e88250e2cab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -762,7 +762,7 @@ class VectorExtractOpConversion
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
auto nMinusOnePositionAttrs =
- ArrayAttr::get(context, positionAttrs.drop_back());
+ ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
@@ -871,7 +871,7 @@ class VectorInsertOpConversion
if (positionAttrs.size() > 1) {
oneDVectorType = reducedVectorTypeBack(destVectorType);
auto nMinusOnePositionAttrs =
- ArrayAttr::get(context, positionAttrs.drop_back());
+ ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
@@ -887,7 +887,7 @@ class VectorInsertOpConversion
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
auto nMinusOnePositionAttrs =
- ArrayAttr::get(context, positionAttrs.drop_back());
+ ArrayAttr::get(positionAttrs.drop_back(), context);
inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
adaptor.dest(), inserted,
nMinusOnePositionAttrs);
diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
index 6ccb59aff35a..c1d0820e1cc7 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
@@ -53,7 +53,7 @@ LogicalResult setMappingAttr(scf::ParallelOp ploopOp,
}
ArrayRef<Attribute> mappingAsAttrs(mapping.data(), mapping.size());
ploopOp->setAttr(getMappingAttrName(),
- ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs));
+ ArrayAttr::get(mappingAsAttrs, ploopOp.getContext()));
return success();
}
} // namespace gpu
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e96668779401..a3960ae94b27 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -225,7 +225,7 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
if (genericAttrNamesSet.count(attr.first.strref()) > 0)
genericAttrs.push_back(attr);
if (!genericAttrs.empty()) {
- auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
+ auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
p << genericDictAttr;
}
@@ -833,7 +833,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
// Handle the corner case of the result being a rank 0 shaped type. Return an
// emtpy ArrayAttr.
if (mapsConsumer.empty() && !mapsProducer.empty())
- return ArrayAttr::get(context, ArrayRef<Attribute>());
+ return ArrayAttr::get(ArrayRef<Attribute>(), context);
if (mapsProducer.empty() || mapsConsumer.empty() ||
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -854,7 +854,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
numLhsDims, /*numSymbols =*/0, reassociations, context)));
reassociations.clear();
}
- return ArrayAttr::get(context, reassociationMaps);
+ return ArrayAttr::get(reassociationMaps, context);
}
namespace {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c7b76404b2f8..8db4824cbbd2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -137,11 +137,11 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
// wrong, so abort.
if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
return nullptr;
- return ArrayAttr::get(context,
- llvm::to_vector<4>(llvm::map_range(
- newIndexingMaps, [](AffineMap map) -> Attribute {
- return AffineMapAttr::get(map);
- })));
+ return ArrayAttr::get(
+ llvm::to_vector<4>(llvm::map_range(
+ newIndexingMaps,
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
+ context);
}
/// Modify the region of indexed generic op to drop arguments corresponding to
@@ -220,7 +220,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
rewriter.startRootUpdate(op);
op.indexing_mapsAttr(newIndexingMapAttr);
- op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
+ op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
(void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
rewriter.finalizeRootUpdate(op);
return success();
@@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
RankedTensorType::get(newShape, type.getElementType()),
AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
newIndexExprs, context),
- ArrayAttr::get(context, reassociationMaps)};
+ ArrayAttr::get(reassociationMaps, context)};
return info;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index b893f2ba6721..cac0ae0d081c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -77,9 +77,9 @@ LinalgOp mlir::linalg::interchange(LinalgOp op,
applyPermutationToVector(itTypesVector, interchangeVector);
op->setAttr(getIndexingMapsAttrName(),
- ArrayAttr::get(context, newIndexingMaps));
+ ArrayAttr::get(newIndexingMaps, context));
op->setAttr(getIteratorTypesAttrName(),
- ArrayAttr::get(context, itTypesVector));
+ ArrayAttr::get(itTypesVector, context));
return op;
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 4ce29b4a8397..9b62b4289c77 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
});
for (auto &var : interfaceVarSet) {
interfaceVars.push_back(SymbolRefAttr::get(
- funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
+ cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
}
return success();
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65ebc54aeeb3..0902b297ddd3 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -338,7 +338,7 @@ OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
return a;
}
// If this is reached, all inputs were statically known passing.
- return BoolAttr::get(getContext(), true);
+ return BoolAttr::get(true, getContext());
}
static LogicalResult verify(AssumingAllOp op) {
@@ -482,10 +482,10 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// Both operands are not needed if one is a scalar.
if (operands[0] &&
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
- return BoolAttr::get(getContext(), true);
+ return BoolAttr::get(true, getContext());
if (operands[1] &&
operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
- return BoolAttr::get(getContext(), true);
+ return BoolAttr::get(true, getContext());
if (operands[0] && operands[1]) {
auto lhsShape = llvm::to_vector<6>(
@@ -494,7 +494,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
- return BoolAttr::get(getContext(), true);
+ return BoolAttr::get(true, getContext());
}
// Lastly, see if folding can be completed based on what constraints are known
@@ -506,7 +506,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
- return BoolAttr::get(getContext(), true);
+ return BoolAttr::get(true, getContext());
// Because a failing witness result here represents an eventual assertion
// failure, we do not replace it with a constant witness.
@@ -526,7 +526,7 @@ void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
if (llvm::all_of(operands,
[&](Attribute a) { return a && a == operands[0]; }))
- return BoolAttr::get(getContext(), true);
+ return BoolAttr::get(true, getContext());
// Because a failing witness result here represents an eventual assertion
// failure, we do not try to replace it with a constant witness. Similarly, we
@@ -573,14 +573,14 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs())
- return BoolAttr::get(getContext(), true);
+ return BoolAttr::get(true, getContext());
auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (lhs == nullptr)
return {};
auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
if (rhs == nullptr)
return {};
- return BoolAttr::get(getContext(), lhs == rhs);
+ return BoolAttr::get(lhs == rhs, getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ca2e2731df03..c085c1cd33a7 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -844,7 +844,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
- return BoolAttr::get(getContext(), val);
+ return BoolAttr::get(val, getContext());
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
@@ -853,7 +853,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return BoolAttr::get(getContext(), val);
+ return BoolAttr::get(val, getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9fe8cf23c162..f20b713e8e77 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -247,7 +247,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
- auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
+ auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
p << op.rhs() << ", " << op.acc();
if (op.masks().size() == 2)
@@ -1445,7 +1445,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
});
- return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
+ return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
}
static LogicalResult verify(InsertStridedSliceOp op) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index bafeccbd53ea..8a5206eb0b1c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -92,11 +92,11 @@ NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
BoolAttr Builder::getBoolAttr(bool value) {
- return BoolAttr::get(context, value);
+ return BoolAttr::get(value, context);
}
DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
- return DictionaryAttr::get(context, value);
+ return DictionaryAttr::get(value, context);
}
IntegerAttr Builder::getIndexAttr(int64_t value) {
@@ -200,11 +200,11 @@ FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
}
StringAttr Builder::getStringAttr(StringRef bytes) {
- return StringAttr::get(context, bytes);
+ return StringAttr::get(bytes, context);
}
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
- return ArrayAttr::get(context, value);
+ return ArrayAttr::get(value, context);
}
FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
@@ -214,12 +214,12 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
return getSymbolRefAttr(symName.getValue());
}
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
- return SymbolRefAttr::get(getContext(), value);
+ return SymbolRefAttr::get(value, getContext());
}
SymbolRefAttr
Builder::getSymbolRefAttr(StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
- return SymbolRefAttr::get(getContext(), value, nestedReferences);
+ return SymbolRefAttr::get(value, nestedReferences, getContext());
}
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 58a5b3370364..162bed96e3f4 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -35,7 +35,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
// ArrayAttr
//===----------------------------------------------------------------------===//
-ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) {
+ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
return Base::get(context, value);
}
@@ -134,8 +134,8 @@ DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
return findDuplicateElement(array);
}
-DictionaryAttr DictionaryAttr::get(MLIRContext *context,
- ArrayRef<NamedAttribute> value) {
+DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
+ MLIRContext *context) {
if (value.empty())
return DictionaryAttr::getEmpty(context);
assert(llvm::all_of(value,
@@ -267,12 +267,13 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
// SymbolRefAttr
//===----------------------------------------------------------------------===//
-FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
+FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
}
-SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
- ArrayRef<FlatSymbolRefAttr> nestedReferences) {
+SymbolRefAttr SymbolRefAttr::get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences,
+ MLIRContext *ctx) {
return Base::get(ctx, value, nestedReferences);
}
@@ -293,7 +294,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
if (type.isSignlessInteger(1))
- return BoolAttr::get(type.getContext(), value.getBoolValue());
+ return BoolAttr::get(value.getBoolValue(), type.getContext());
return Base::get(type.getContext(), type, value);
}
@@ -376,8 +377,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
// OpaqueAttr
//===----------------------------------------------------------------------===//
-OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
- StringRef attrData, Type type) {
+OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
+ MLIRContext *context) {
return Base::get(context, dialect, attrData, type);
}
@@ -408,7 +409,7 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
// StringAttr
//===----------------------------------------------------------------------===//
-StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) {
+StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
return get(bytes, NoneType::get(context));
}
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index db383c691c7c..469aa310140c 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -166,7 +166,7 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
newAttrs.insert(attr);
for (auto &attr : getAttrs())
newAttrs.insert(attr);
- dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
+ dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext()));
// Clone the body.
getBody().cloneInto(&dest.getBody(), mapper);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 8d13a9c4af32..dbfa1bdf6f7e 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -872,7 +872,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
storage->setType(NoneType::get(ctx));
}
-BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
+BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
}
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index be312689cebb..b4fe9f854dda 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -76,7 +76,7 @@ Operation *Operation::create(Location location, OperationName name,
ArrayRef<NamedAttribute> attributes,
BlockRange successors, unsigned numRegions) {
return create(location, name, resultTypes, operands,
- DictionaryAttr::get(location.getContext(), attributes),
+ DictionaryAttr::get(attributes, location.getContext()),
successors, numRegions);
}
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 70133d22482f..b198600e9242 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -46,7 +46,7 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
MLIRContext *ctx = symbol->getContext();
- auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName);
+ auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
results.push_back(leafRef);
// Early exit for when 'within' is the parent of 'symbol'.
@@ -67,13 +67,13 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
getNameIfSymbol(symbolTableOp, symbolNameId);
if (!symbolTableName)
return failure();
- results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs));
+ results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
symbolTableOp = symbolTableOp->getParentOp();
if (symbolTableOp == within)
break;
nestedRefs.insert(nestedRefs.begin(),
- FlatSymbolRefAttr::get(ctx, *symbolTableName));
+ FlatSymbolRefAttr::get(*symbolTableName, ctx));
} while (true);
return success();
}
@@ -203,7 +203,7 @@ StringRef SymbolTable::getSymbolName(Operation *symbol) {
/// Sets the name of the given symbol operation.
void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
symbol->setAttr(getSymbolAttrName(),
- StringAttr::get(symbol->getContext(), name));
+ StringAttr::get(name, symbol->getContext()));
}
/// Returns the visibility of the given symbol operation.
@@ -235,7 +235,7 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
"unknown symbol visibility kind");
StringRef visName = vis == Visibility::Private ? "private" : "nested";
- symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
+ symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
}
/// Returns the nearest symbol table from a given operation `from`. Returns
@@ -603,7 +603,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
// doesn't support parent references.
if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
symbol->getParentOp())
- return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}};
+ return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
return {};
}
@@ -659,7 +659,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
template <typename IRUnit>
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
IRUnit *limit) {
- return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}};
+ return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
}
/// Returns true if the given reference 'SubRef' is a sub reference of the
@@ -825,11 +825,11 @@ static Attribute rebuildAttrAfterRAUW(
if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
updateAttrs(make_second_range(newAttrs));
- return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
+ return DictionaryAttr::get(newAttrs, dictAttr.getContext());
}
auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
updateAttrs(newAttrs);
- return ArrayAttr::get(container.getContext(), newAttrs);
+ return ArrayAttr::get(newAttrs, container.getContext());
}
/// Generates a new symbol reference attribute with a new leaf reference.
@@ -839,8 +839,8 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
return newLeafAttr;
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
nestedRefs.back() = newLeafAttr;
- return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(),
- nestedRefs);
+ return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
+ oldAttr.getContext());
}
/// The implementation of SymbolTable::replaceAllSymbolUses below.
@@ -867,7 +867,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
// Generate a new attribute to replace the given attribute.
MLIRContext *ctx = limit->getContext();
- FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol);
+ FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
@@ -883,13 +883,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
if (useRef != scope.symbol) {
if (scope.symbol.isa<FlatSymbolRefAttr>()) {
replacementRef =
- SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences());
+ SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
} else {
auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
newLeafAttr;
replacementRef =
- SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs);
+ SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
}
}
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 98f74174e5a3..859e8e279917 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -148,7 +148,7 @@ Attribute Parser::parseAttribute(Type type) {
return Attribute();
return type ? StringAttr::get(val, type)
- : StringAttr::get(getContext(), val);
+ : StringAttr::get(val, getContext());
}
// Parse a symbol reference attribute.
@@ -176,7 +176,7 @@ Attribute Parser::parseAttribute(Type type) {
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
- nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
+ nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
}
return builder.getSymbolRefAttr(nameStr, nestedRefs);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 52ce37eb79ab..2f0b3379d152 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -742,8 +742,7 @@ void OpEmitter::genAttrGetters() {
body << " ::mlir::MLIRContext* ctx = getContext();\n";
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
- body << " return ::mlir::DictionaryAttr::get(";
- body << " ctx, {\n";
+ body << " return ::mlir::DictionaryAttr::get({\n";
interleave(
derivedAttrs, body,
[&](const NamedAttribute &namedAttr) {
@@ -756,7 +755,7 @@ void OpEmitter::genAttrGetters() {
<< "}";
},
",\n");
- body << "});";
+ body << "\n }, ctx);";
}
}
}
diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
index 52f522387017..5595986e4016 100644
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -150,7 +150,7 @@ static void emitFactoryDef(llvm::StringRef structName,
}
const char *getEndInfo = R"(
- ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields);
+ ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context);
return dict.dyn_cast<{0}>();
}
)";
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index ef0bdd81ee3a..0dd9ef9de3e6 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -67,7 +67,7 @@ TEST(StructsGenTest, ClassofExtraFalse) {
newValues.push_back(wrongAttr);
// Make a new DictionaryAttr and validate.
- auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
+ auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@@ -88,7 +88,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) {
auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
newValues.push_back(wrongAttr);
- auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
+ auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@@ -113,7 +113,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
newValues.push_back(wrongAttr);
- auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
+ auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@@ -130,7 +130,7 @@ TEST(StructsGenTest, ClassofMissingFalse) {
expectedValues.begin() + 1, expectedValues.end());
// Make a new DictionaryAttr and validate it is not a validate TestStruct.
- auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
+ auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
More information about the Mlir-commits mailing list