[Mlir-commits] [mlir] cdfeeb8 - [mlir:ODS] Generate unwrapped operation attribute setters
River Riddle
llvmlistbot at llvm.org
Fri Oct 14 16:04:35 PDT 2022
Author: River Riddle
Date: 2022-10-14T15:57:51-07:00
New Revision: cdfeeb8a4058130d8ce59300867e272642c97dfa
URL: https://github.com/llvm/llvm-project/commit/cdfeeb8a4058130d8ce59300867e272642c97dfa
DIFF: https://github.com/llvm/llvm-project/commit/cdfeeb8a4058130d8ce59300867e272642c97dfa.diff
LOG: [mlir:ODS] Generate unwrapped operation attribute setters
This allows for setting an attribute using the underlying C++ type,
which is generally much nicer to interact with than the attribute type.
Differential Revision: https://reviews.llvm.org/D135838
Added:
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/test/mlir-tblgen/op-attribute.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index df0ce36829c4c..bb505c52168d6 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -718,7 +718,6 @@ def AffineParallelOp : Affine_Op<"parallel",
/// Sets elements of the loop lower bound.
void setLowerBounds(ValueRange operands, AffineMap map);
- void setLowerBoundsMap(AffineMap map);
/// Returns elements of the loop upper bound.
AffineMap getUpperBoundMap(unsigned pos);
@@ -727,7 +726,6 @@ def AffineParallelOp : Affine_Op<"parallel",
/// Sets elements fo the loop upper bound.
void setUpperBounds(ValueRange operands, AffineMap map);
- void setUpperBoundsMap(AffineMap map);
void setSteps(ArrayRef<int64_t> newSteps);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f200135847b8e..1d05601f4fbea 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3579,22 +3579,6 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
setUpperBoundsMapAttr(AffineMapAttr::get(map));
}
-void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
- AffineMap lbMap = getLowerBoundsMap();
- assert(lbMap.getNumDims() == map.getNumDims() &&
- lbMap.getNumSymbols() == map.getNumSymbols());
- (void)lbMap;
- setLowerBoundsMapAttr(AffineMapAttr::get(map));
-}
-
-void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
- AffineMap ubMap = getUpperBoundsMap();
- assert(ubMap.getNumDims() == map.getNumDims() &&
- ubMap.getNumSymbols() == map.getNumSymbols());
- (void)ubMap;
- setUpperBoundsMapAttr(AffineMapAttr::get(map));
-}
-
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0378f5f9a1a74..d1d03a549092d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1481,7 +1481,7 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
Pred origPred = getPredicate();
for (auto pred : invPreds) {
if (origPred == pred.first) {
- setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
+ setPredicate(pred.second);
Value lhs = getLhs();
Value rhs = getRhs();
getLhsMutable().assign(rhs);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 365febf32ef2a..3e9c235a53633 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2551,8 +2551,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
dynamicIndices);
getDynamicIndicesMutable().assign(dynamicIndices);
- setRawConstantIndicesAttr(
- DenseI32ArrayAttr::get(getContext(), rawConstantIndices));
+ setRawConstantIndices(rawConstantIndices);
return Value{*this};
}
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 302afdcf14ea4..0070000cb5f8b 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -640,10 +640,9 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) {
b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
}
if (gv->hasAtLeastLocalUnnamedAddr())
- op.setUnnamedAddrAttr(UnnamedAddrAttr::get(
- context, convertUnnamedAddrFromLLVM(gv->getUnnamedAddr())));
+ op.setUnnamedAddr(convertUnnamedAddrFromLLVM(gv->getUnnamedAddr()));
if (gv->hasSection())
- op.setSectionAttr(b.getStringAttr(gv->getSection()));
+ op.setSection(gv->getSection());
return globals[gv] = op;
}
@@ -1046,13 +1045,13 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
}
if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
- fop->setAttr(b.getStringAttr("personality"), personality);
+ fop.setPersonalityAttr(personality);
else if (f->hasPersonalityFn())
emitWarning(UnknownLoc::get(context),
"could not deduce personality, skipping it");
if (f->hasGC())
- fop.setGarbageCollectorAttr(b.getStringAttr(f->getGC()));
+ fop.setGarbageCollector(StringRef(f->getGC()));
// Handle Function attributes.
processFunctionAttributes(f, fop);
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index e6cc49dfb0495..7e7a76227f37e 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -127,10 +127,18 @@ def AOp : NS_Op<"a_op", []> {
// DEF: void AOp::setAAttrAttr(some-attr-kind attr) {
// DEF-NEXT: (*this)->setAttr(getAAttrAttrName(), attr);
+// DEF: void AOp::setAAttr(some-return-type attrValue) {
+// DEF-NEXT: (*this)->setAttr(getAAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), attrValue));
// DEF: void AOp::setBAttrAttr(some-attr-kind attr) {
// DEF-NEXT: (*this)->setAttr(getBAttrAttrName(), attr);
+// DEF: void AOp::setBAttr(some-return-type attrValue) {
+// DEF-NEXT: (*this)->setAttr(getBAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), attrValue));
// DEF: void AOp::setCAttrAttr(some-attr-kind attr) {
// DEF-NEXT: (*this)->setAttr(getCAttrAttrName(), attr);
+// DEF: void AOp::setCAttr(::llvm::Optional<some-return-type> attrValue) {
+// DEF-NEXT: if (attrValue)
+// DEF-NEXT: return (*this)->setAttr(getCAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), *attrValue));
+// DEF-NEXT: (*this)->removeAttr(getCAttrAttrName());
// Test remove methods
// ---
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6304f74ea8e91..5b3d0ad18eec5 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -188,6 +188,22 @@ static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
!attr.getConstBuilderTemplate().empty();
}
+/// Build an attribute from a parameter value using the constant builder.
+static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
+ FmtContext &fctx,
+ StringRef paramName) {
+ std::string builderTemplate = attr.getConstBuilderTemplate().str();
+
+ // For StringAttr, its constant builder call will wrap the input in
+ // quotes, which is correct for normal string literals, but incorrect
+ // here given we use function arguments. So we need to strip the
+ // wrapping quotes.
+ if (StringRef(builderTemplate).contains("\"$0\""))
+ builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
+
+ return tgfmt(builderTemplate, &fctx, paramName).str();
+}
+
namespace {
/// Metadata on a registered attribute. Given that attributes are stored in
/// sorted order on operations, we can use information from ODS to deduce the
@@ -1092,13 +1108,69 @@ void OpEmitter::genAttrSetters() {
getterName);
};
+ // Generate a setter that accepts the underlying C++ type as opposed to the
+ // attribute type.
+ auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
+ Attribute attr) {
+ Attribute baseAttr = attr.getBaseAttr();
+ if (!canUseUnwrappedRawValue(baseAttr))
+ return;
+ FmtContext fctx;
+ fctx.withBuilder("::mlir::Builder(getContext())");
+ bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
+ bool isOptional = attr.isOptional();
+
auto createMethod = [&](const Twine &paramType) {
+ return opClass.addMethod("void", setterName,
+ MethodParameter(paramType.str(), "attrValue"));
+ };
+
+ // Build the method using the correct parameter type depending on
+ // optionality.
+ Method *method = nullptr;
+ if (isUnitAttr)
+ method = createMethod("bool");
+ else if (isOptional)
+ method =
+ createMethod("::llvm::Optional<" + baseAttr.getReturnType() + ">");
+ else
+ method = createMethod(attr.getReturnType());
+ if (!method)
+ return;
+
+ // If the value isn't optional, just set it directly.
+ if (!isOptional) {
+ method->body() << formatv(
+ " (*this)->setAttr({0}AttrName(), {1});", getterName,
+ constBuildAttrFromParam(attr, fctx, "attrValue"));
+ return;
+ }
+
+ // Otherwise, we only set if the provided value is valid. If it isn't, we
+ // remove the attribute.
+
+ // TODO: Handle unit attr parameters specially, given that it is treated as
+ // optional but not in the same way as the others (i.e. it uses bool over
+ // Optional<>).
+ StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
+ const char *optionalCodeBody = R"(
+ if (attrValue)
+ return (*this)->setAttr({0}AttrName(), {1});
+ (*this)->removeAttr({0}AttrName());)";
+ method->body() << formatv(
+ optionalCodeBody, getterName,
+ constBuildAttrFromParam(baseAttr, fctx, paramStr));
+ };
+
for (const NamedAttribute &namedAttr : op.getAttributes()) {
if (namedAttr.attr.isDerivedAttr())
continue;
- for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
- op.getGetterNames(namedAttr.name)))
- emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
- namedAttr.attr);
+ for (auto [setterName, getterName] :
+ llvm::zip(op.getSetterNames(namedAttr.name),
+ op.getGetterNames(namedAttr.name))) {
+ emitAttrWithStorageType(setterName, getterName, namedAttr.attr);
+ emitAttrWithReturnType(setterName, getterName, namedAttr.attr);
+ }
}
}
@@ -2160,20 +2232,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
// instance.
FmtContext fctx;
fctx.withBuilder("odsBuilder");
-
- std::string builderTemplate = std::string(attr.getConstBuilderTemplate());
-
- // For StringAttr, its constant builder call will wrap the input in
- // quotes, which is correct for normal string literals, but incorrect
- // here given we use function arguments. So we need to strip the
- // wrapping quotes.
- if (StringRef(builderTemplate).contains("\"$0\""))
- builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
-
- std::string value =
- std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
- builderOpState, op.getGetterName(namedAttr.name), value);
+ builderOpState, op.getGetterName(namedAttr.name),
+ constBuildAttrFromParam(attr, fctx, namedAttr.name));
} else {
body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
builderOpState, op.getGetterName(namedAttr.name),
More information about the Mlir-commits
mailing list