[Mlir-commits] [mlir] bbe5bf1 - Cleanup uses of getAttrDictionary() in MLIR to use getDiscardableAttrDictionary() when possible
Mehdi Amini
llvmlistbot at llvm.org
Mon May 15 11:36:43 PDT 2023
Author: Mehdi Amini
Date: 2023-05-15T11:35:50-07:00
New Revision: bbe5bf1788b55e3c7020d50ee0fd5956f261cfec
URL: https://github.com/llvm/llvm-project/commit/bbe5bf1788b55e3c7020d50ee0fd5956f261cfec
DIFF: https://github.com/llvm/llvm-project/commit/bbe5bf1788b55e3c7020d50ee0fd5956f261cfec.diff
LOG: Cleanup uses of getAttrDictionary() in MLIR to use getDiscardableAttrDictionary() when possible
This also speeds up some benchmarks that compile a simple Fortran file by 2x!
Fixes #62687
Differential Revision: https://reviews.llvm.org/D150540
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/IR/Verifier.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/IR/TestOperationEquals.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f17a2e54a731..1a362f602b53 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -147,11 +147,11 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const final {
if constexpr (SourceOp::hasProperties())
rewrite(cast<SourceOp>(op),
- OpAdaptor(operands, op->getAttrDictionary(),
+ OpAdaptor(operands, op->getDiscardableAttrDictionary(),
cast<SourceOp>(op).getProperties()),
rewriter);
- rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
- rewriter);
+ rewrite(cast<SourceOp>(op),
+ OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
@@ -161,12 +161,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const final {
if constexpr (SourceOp::hasProperties())
return matchAndRewrite(cast<SourceOp>(op),
- OpAdaptor(operands, op->getAttrDictionary(),
+ OpAdaptor(operands,
+ op->getDiscardableAttrDictionary(),
cast<SourceOp>(op).getProperties()),
rewriter);
- return matchAndRewrite(cast<SourceOp>(op),
- OpAdaptor(operands, op->getAttrDictionary()),
- rewriter);
+ return matchAndRewrite(
+ cast<SourceOp>(op),
+ OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index d08d3de350f3..71864def4543 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1890,11 +1890,12 @@ class Op : public OpState, public Traits<ConcreteType>... {
if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
if constexpr (hasProperties()) {
result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
- operands, op->getAttrDictionary(),
+ operands, op->getDiscardableAttrDictionary(),
cast<ConcreteOpT>(op).getProperties(), op->getRegions()));
} else {
result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
- operands, op->getAttrDictionary(), {}, op->getRegions()));
+ operands, op->getDiscardableAttrDictionary(), {},
+ op->getRegions()));
}
} else {
result = cast<ConcreteOpT>(op).fold(operands);
@@ -1920,13 +1921,14 @@ class Op : public OpState, public Traits<ConcreteType>... {
if constexpr (hasProperties()) {
result = cast<ConcreteOpT>(op).fold(
typename ConcreteOpT::FoldAdaptor(
- operands, op->getAttrDictionary(),
+ operands, op->getDiscardableAttrDictionary(),
cast<ConcreteOpT>(op).getProperties(), op->getRegions()),
results);
} else {
result = cast<ConcreteOpT>(op).fold(
- typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
- {}, op->getRegions()),
+ typename ConcreteOpT::FoldAdaptor(
+ operands, op->getDiscardableAttrDictionary(), {},
+ op->getRegions()),
results);
}
} else {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 020c8ce9ab4e..f242eea76778 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -520,7 +520,10 @@ class OpConversionPattern : public ConversionPattern {
}
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp,
+ OpAdaptor(operands, op->getDiscardableAttrDictionary(),
+ sourceOp.getProperties()),
rewriter);
}
LogicalResult
@@ -529,11 +532,13 @@ class OpConversionPattern : public ConversionPattern {
auto sourceOp = cast<SourceOp>(op);
if constexpr (SourceOp::hasProperties())
return matchAndRewrite(sourceOp,
- OpAdaptor(operands, op->getAttrDictionary(),
+ OpAdaptor(operands,
+ op->getDiscardableAttrDictionary(),
sourceOp.getProperties()),
rewriter);
return matchAndRewrite(
- sourceOp, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
+ sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()),
+ rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 17b9b7404768..47c2cdbf5260 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -132,8 +132,11 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- return matchAndRewrite(cast<memref::ReallocOp>(op),
- OpAdaptor(operands, op->getAttrDictionary()),
+ auto reallocOp = cast<memref::ReallocOp>(op);
+ return matchAndRewrite(reallocOp,
+ OpAdaptor(operands,
+ op->getDiscardableAttrDictionary(),
+ reallocOp.getProperties()),
rewriter);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 1fbe66ff98d7..40903f199afd 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -111,10 +111,10 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
resultTypes.reserve(1 + op->getNumResults());
copy(op->getResultTypes(), std::back_inserter(resultTypes));
resultTypes.push_back(tokenType);
- auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
- op->getOperands(), op->getAttrDictionary(),
- op->getPropertiesStorage(),
- op->getSuccessors(), op->getNumRegions());
+ auto *newOp = Operation::create(
+ op->getLoc(), op->getName(), resultTypes, op->getOperands(),
+ op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
+ op->getSuccessors(), op->getNumRegions());
// Clone regions into new op.
IRMapping mapping;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 3aa1c3f1c2c5..36e967d2d578 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -393,7 +393,9 @@ struct ConvertSelectionOpToSelect
}
bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
- return lhs->getAttrDictionary() == rhs->getAttrDictionary();
+ return lhs->getDiscardableAttrDictionary() ==
+ rhs->getDiscardableAttrDictionary() &&
+ lhs->hashProperties() == rhs->hashProperties();
}
// Returns a source value for the given block.
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 1a056a0d597c..24302544ea06 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -920,7 +920,8 @@ LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
Builder b(context);
- auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
+ Properties *prop = properties.as<Properties *>();
+ DenseIntElementsAttr shape = prop->shape;
if (!shape)
return emitOptionalError(location, "missing shape attribute");
inferredReturnTypes.assign({RankedTensorType::get(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1b063e75eb93..1040d4c15965 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -383,7 +383,8 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
- IntegerAttr axis = llvm::cast<IntegerAttr>(attributes.get("axis"));
+ auto *prop = properties.as<Properties *>();
+ IntegerAttr axis = prop->axis;
int32_t axisVal = axis.getValue().getSExtValue();
if (!inputShape.hasRank()) {
@@ -446,8 +447,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Infer all dimension sizes by reducing based on inputs.
- int32_t axis =
- llvm::cast<IntegerAttr>(attributes.get("axis")).getValue().getSExtValue();
+ auto *prop = properties.as<Properties *>();
+ int32_t axis = prop->axis.getValue().getSExtValue();
llvm::SmallVector<int64_t> outputShape;
bool hasRankedInput = false;
for (auto operand : operands) {
@@ -985,7 +986,7 @@ static LogicalResult ReduceInferReturnTypes(
Type inputType = \
operands.getType()[0].cast<TensorType>().getElementType(); \
return ReduceInferReturnTypes(operands.getShape(0), inputType, \
- attributes.get("axis").cast<IntegerAttr>(), \
+ properties.as<Properties *>()->axis, \
inferredReturnShapes); \
} \
COMPATIBLE_RETURN_TYPES(OP)
@@ -1062,6 +1063,7 @@ NARY_SHAPE_INFER(tosa::SigmoidOp)
static LogicalResult poolingInferReturnTypes(
const ValueShapeRange &operands, DictionaryAttr attributes,
+ ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
llvm::SmallVector<int64_t> outputShape;
@@ -1080,12 +1082,6 @@ static LogicalResult poolingInferReturnTypes(
int64_t height = inputShape.getDimSize(1);
int64_t width = inputShape.getDimSize(2);
- ArrayRef<int64_t> kernel =
- llvm::cast<DenseI64ArrayAttr>(attributes.get("kernel"));
- ArrayRef<int64_t> stride =
- llvm::cast<DenseI64ArrayAttr>(attributes.get("stride"));
- ArrayRef<int64_t> pad = llvm::cast<DenseI64ArrayAttr>(attributes.get("pad"));
-
if (!ShapedType::isDynamic(height)) {
int64_t padded = height + pad[0] + pad[1] - kernel[0];
outputShape[1] = padded / stride[0] + 1;
@@ -1245,7 +1241,9 @@ LogicalResult AvgPool2dOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
+ Properties &prop = *properties.as<Properties *>();
+ return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
+ prop.pad, inferredReturnShapes);
}
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
@@ -1253,7 +1251,9 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
+ Properties &prop = *properties.as<Properties *>();
+ return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
+ prop.pad, inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 87563c1761a8..50a556dfc694 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -37,10 +37,10 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
- .inferReturnTypeComponents(op.getContext(), op.getLoc(),
- op->getOperands(), op->getAttrDictionary(),
- op->getPropertiesStorage(),
- op->getRegions(), returnedShapes)
+ .inferReturnTypeComponents(
+ op.getContext(), op.getLoc(), op->getOperands(),
+ op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
+ op->getRegions(), returnedShapes)
.failed())
return op;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 3e2da9df3f94..65b66d29d6f8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -218,9 +218,10 @@ void propagateShapesInRegion(Region &region) {
ValueShapeRange range(op.getOperands(), operandShape);
if (shapeInterface
- .inferReturnTypeComponents(
- op.getContext(), op.getLoc(), range, op.getAttrDictionary(),
- op.getPropertiesStorage(), op.getRegions(), returnedShapes)
+ .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
+ op.getDiscardableAttrDictionary(),
+ op.getPropertiesStorage(),
+ op.getRegions(), returnedShapes)
.succeeded()) {
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index b8f360122239..716239af863c 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -653,7 +653,7 @@ llvm::hash_code OperationEquivalence::computeHash(
// - Attributes
// - Result Types
llvm::hash_code hash =
- llvm::hash_combine(op->getName(), op->getAttrDictionary(),
+ llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
op->getResultTypes(), op->hashProperties());
// - Operands
@@ -768,11 +768,13 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
// 1. Compare the operation properties.
if (lhs->getName() != rhs->getName() ||
- lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
+ lhs->getDiscardableAttrDictionary() !=
+ rhs->getDiscardableAttrDictionary() ||
lhs->getNumRegions() != rhs->getNumRegions() ||
lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
lhs->getNumOperands() != rhs->getNumOperands() ||
- lhs->getNumResults() != rhs->getNumResults())
+ lhs->getNumResults() != rhs->getNumResults() ||
+ lhs->hashProperties() != rhs->hashProperties())
return false;
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
@@ -876,7 +878,9 @@ OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
// - Operation pointer
addDataToHash(hasher, op);
// - Attributes
- addDataToHash(hasher, op->getAttrDictionary());
+ addDataToHash(hasher, op->getDiscardableAttrDictionary());
+ // - Properties
+ addDataToHash(hasher, op->hashProperties());
// - Blocks in Regions
for (Region &region : op->getRegions()) {
for (Block &block : region) {
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 68e498d57324..a7f84beaa916 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -174,7 +174,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
return op.emitError("null operand found");
/// Verify that all of the attributes are okay.
- for (auto attr : op.getAttrs()) {
+ for (auto attr : op.getDiscardableAttrDictionary()) {
// Check for any optional dialect specific attributes.
if (auto *dialect = attr.getNameDialect())
if (failed(dialect->verifyOperationAttribute(&op, attr)))
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 80ed2cc3f22e..aaa1e1b24525 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -251,8 +251,8 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
auto retTypeFn = cast<InferTypeOpInterface>(op);
auto result = retTypeFn.refineReturnTypes(
op->getContext(), op->getLoc(), op->getOperands(),
- op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions(),
- inferredReturnTypes);
+ op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
+ op->getRegions(), inferredReturnTypes);
if (failed(result))
op->emitOpError() << "failed to infer returned types";
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 82ae72ab6a27..3a1faeabe84c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -436,7 +436,7 @@ static void invokeCreateWithInferredReturnType(Operation *op) {
std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
SmallVector<Type, 2> inferredReturnTypes;
if (succeeded(OpTy::inferReturnTypes(
- context, std::nullopt, values, op->getAttrDictionary(),
+ context, std::nullopt, values, op->getDiscardableAttrDictionary(),
op->getPropertiesStorage(), op->getRegions(),
inferredReturnTypes))) {
OperationState state(location, OpTy::getOperationName());
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
index ef3589636e52..03cf5f4facf8 100644
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -31,7 +31,7 @@ struct TestOperationEqualPass
Operation *first = &module.getBody()->front();
llvm::outs() << first->getName().getStringRef() << " with attr "
- << first->getAttrDictionary();
+ << first->getDiscardableAttrDictionary();
OperationEquivalence::Flags flags{};
if (!first->hasAttr("strict_loc_check"))
flags |= OperationEquivalence::IgnoreLocations;
More information about the Mlir-commits
mailing list