[Mlir-commits] [mlir] 6089d61 - [mlir] Prevent implicit downcasting to interfaces
Rahul Kayaith
llvmlistbot at llvm.org
Thu Apr 20 13:32:04 PDT 2023
Author: Rahul Kayaith
Date: 2023-04-20T16:31:54-04:00
New Revision: 6089d612a580738df00c22a43e6f2c29bd216af9
URL: https://github.com/llvm/llvm-project/commit/6089d612a580738df00c22a43e6f2c29bd216af9
DIFF: https://github.com/llvm/llvm-project/commit/6089d612a580738df00c22a43e6f2c29bd216af9.diff
LOG: [mlir] Prevent implicit downcasting to interfaces
Currently conversions to interfaces may happen implicitly (e.g.
`Attribute -> TypedAttr`), failing a runtime assert if the interface
isn't actually implemented. This change marks the `Interface(ValueT)`
constructor as explicit so that a cast is required.
Where it was straightforward to do so, I adjusted code to not require casts;
otherwise I just made them explicit.
Depends on D148491, D148492
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D148493
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/Arith.h
mlir/include/mlir/Dialect/CommonFolders.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/FunctionInterfaces.td
mlir/include/mlir/Support/InterfaceSupport.h
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Transforms/Utils/InliningUtils.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 3e14e4d346753..f285262982816 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -121,7 +121,7 @@ bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
/// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc);
/// Returns the identity value associated with an AtomicRMWKind op.
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 5633b4ccd7cf6..8007027a86514 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -64,7 +64,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!elementResult)
return {};
- return DenseElementsAttr::get(resultType, *elementResult);
+ return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
}
if (operands[0].isa<ElementsAttr>() && operands[1].isa<ElementsAttr>()) {
@@ -86,7 +86,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
elementResults.push_back(*elementResult);
}
- return DenseElementsAttr::get(resultType, elementResults);
+ return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
}
return {};
}
@@ -233,7 +233,7 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
calculate(op.getSplatValue<ElementValueT>(), castStatus);
if (!castStatus)
return {};
- return DenseElementsAttr::get(resType, elementResult);
+ return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
}
if (operands[0].isa<ElementsAttr>()) {
// Operand is ElementsAttr-derived; perform an element-wise fold by
@@ -250,7 +250,7 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
elementResults.push_back(elt);
}
- return DenseElementsAttr::get(resType, elementResults);
+ return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
}
return {};
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 0d8b89fdf9539..77376ff90cdb2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -132,7 +132,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
/// Return the identity numeric value associated to the give op. Return
/// std::nullopt if there is no known neutral element.
-std::optional<Attribute> getNeutralElement(Operation *op);
+std::optional<TypedAttr> getNeutralElement(Operation *op);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f45781a978900..ae6c284eaf6f9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1618,7 +1618,7 @@ def ReductionDeclareOp : OpenMP_Op<"reduction.declare", [Symbol,
if (getAtomicReductionRegion().empty())
return {};
- return getAtomicReductionRegion().front().getArgument(0).getType();
+ return cast<PointerLikeType>(getAtomicReductionRegion().front().getArgument(0).getType());
}
}];
let hasRegionVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index c4f9fa8a6fe05..dab24bd930326 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -20,6 +20,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeRange.h"
#include "mlir/Support/LLVM.h"
// Pull in all enum type definitions and utility function declarations.
@@ -28,8 +29,6 @@
namespace mlir {
class OpBuilder;
-class TypeRange;
-class ValueRange;
class RewriterBase;
/// Tests whether the given maps describe a row major matmul. The test is
@@ -116,6 +115,11 @@ class StructuredGenerator {
// Note: this is a true builder that notifies the OpBuilder listener.
Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
ValueRange newOperands);
+template <typename OpT>
+OpT clone(OpBuilder &b, OpT op, TypeRange newResultTypes,
+ ValueRange newOperands) {
+ return cast<OpT>(clone(b, op.getOperation(), newResultTypes, newOperands));
+}
// Clone the current operation with the operands but leave the regions empty.
// Note: this is a true builder that notifies the OpBuilder listener.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 7197b1364bbce..1b0f4bfb3f629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -116,7 +116,7 @@ class Builder {
// Returns a 0-valued attribute of the given `type`. This function only
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
// ranked tensor of them. Returns null attribute otherwise.
- Attribute getZeroAttr(Type type);
+ TypedAttr getZeroAttr(Type type);
// Convenience methods for fixed types.
FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index d8506cbcdad10..496c197e47152 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -78,9 +78,9 @@ class DenseElementsAttr : public Attribute {
using Attribute::Attribute;
/// Allow implicit conversion to ElementsAttr.
- operator ElementsAttr() const {
- return *this ? cast<ElementsAttr>() : nullptr;
- }
+ operator ElementsAttr() const { return cast_if_present<ElementsAttr>(*this); }
+ /// Allow implicit conversion to TypedAttr.
+ operator TypedAttr() const { return ElementsAttr(*this); }
/// Type trait used to check if the given type T is a potentially valid C++
/// floating point type that can be used to access the underlying element
@@ -842,9 +842,10 @@ class BoolAttr : public Attribute {
static BoolAttr get(MLIRContext *context, bool value);
- /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
- /// avoid bringing in all of IntegerAttrs methods.
+ /// Enable conversion to IntegerAttr and its interfaces. This uses conversion
+ /// vs. inheritance to avoid bringing in all of IntegerAttrs methods.
operator IntegerAttr() const { return IntegerAttr(impl); }
+ operator TypedAttr() const { return IntegerAttr(impl); }
/// Return the boolean value of this attribute.
bool getValue() const;
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index c30454a5268c1..17bbdcccaed29 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -275,7 +275,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
/// has less parameters we drop the extra attributes, if there are more
/// parameters they won't have any attributes.
void setType(Type newType) {
- function_interface_impl::setFunctionType(this->getOperation(), newType);
+ function_interface_impl::setFunctionType($_op, newType);
}
//===------------------------------------------------------------------===//
@@ -316,7 +316,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
Type newType = $_op.getTypeWithArgsAndResults(
argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
function_interface_impl::insertFunctionArguments(
- this->getOperation(), argIndices, argTypes, argAttrs, argLocs,
+ $_op, argIndices, argTypes, argAttrs, argLocs,
originalNumArgs, newType);
}
@@ -336,7 +336,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
Type newType = $_op.getTypeWithArgsAndResults(
/*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes);
function_interface_impl::insertFunctionResults(
- this->getOperation(), resultIndices, resultTypes, resultAttrs,
+ $_op, resultIndices, resultTypes, resultAttrs,
originalNumResults, newType);
}
@@ -351,7 +351,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
void eraseArguments(const BitVector &argIndices) {
Type newType = $_op.getTypeWithoutArgs(argIndices);
function_interface_impl::eraseFunctionArguments(
- this->getOperation(), argIndices, newType);
+ $_op, argIndices, newType);
}
/// Erase a single result at `resultIndex`.
@@ -365,7 +365,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
void eraseResults(const BitVector &resultIndices) {
Type newType = $_op.getTypeWithoutResults(resultIndices);
function_interface_impl::eraseFunctionResults(
- this->getOperation(), resultIndices, newType);
+ $_op, resultIndices, newType);
}
/// Return the type of this function with the specified arguments and
@@ -414,7 +414,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
- return function_interface_impl::getArgAttrs(this->getOperation(), index);
+ return function_interface_impl::getArgAttrs($_op, index);
}
/// Return an ArrayAttr containing all argument attribute dictionaries of
@@ -464,11 +464,11 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
}
void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes) {
assert(attributes.size() == $_op.getNumArguments());
- function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+ function_interface_impl::setAllArgAttrDicts($_op, attributes);
}
void setAllArgAttrs(ArrayRef<Attribute> attributes) {
assert(attributes.size() == $_op.getNumArguments());
- function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+ function_interface_impl::setAllArgAttrDicts($_op, attributes);
}
void setAllArgAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumArguments());
@@ -503,7 +503,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
/// Return all of the attributes for the result at 'index'.
ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
- return function_interface_impl::getResultAttrs(this->getOperation(), index);
+ return function_interface_impl::getResultAttrs($_op, index);
}
/// Return an ArrayAttr containing all result attribute dictionaries of this
@@ -554,12 +554,12 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes) {
assert(attributes.size() == $_op.getNumResults());
function_interface_impl::setAllResultAttrDicts(
- this->getOperation(), attributes);
+ $_op, attributes);
}
void setAllResultAttrs(ArrayRef<Attribute> attributes) {
assert(attributes.size() == $_op.getNumResults());
function_interface_impl::setAllResultAttrDicts(
- this->getOperation(), attributes);
+ $_op, attributes);
}
void setAllResultAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumResults());
@@ -589,7 +589,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
/// attribute is returned.
DictionaryAttr getArgAttrDict(unsigned index) {
assert(index < $_op.getNumArguments() && "invalid argument number");
- return function_interface_impl::getArgAttrDict(this->getOperation(), index);
+ return function_interface_impl::getArgAttrDict($_op, index);
}
/// Returns the dictionary attribute corresponding to the result at 'index'.
@@ -597,7 +597,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
/// returned.
DictionaryAttr getResultAttrDict(unsigned index) {
assert(index < $_op.getNumResults() && "invalid result number");
- return function_interface_impl::getResultAttrDict(this->getOperation(), index);
+ return function_interface_impl::getResultAttrDict($_op, index);
}
}];
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index 71c4003708356..20e2dfb453ef7 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -91,7 +91,7 @@ class Interface : public BaseType {
};
/// Construct an interface from an instance of the value type.
- Interface(ValueT t = ValueT())
+ explicit Interface(ValueT t = ValueT())
: BaseType(t),
conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
assert((!t || conceptImpl) &&
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b6317053e0a6d..4651c29997f8e 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -296,22 +296,22 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
// on operands zero-extended to i(2*N) bits, and truncate the results back to
// iN types.
if (!resultType.isa<LLVM::LLVMArrayType>()) {
- Type wideType;
// Shift amount necessary to extract the high bits from widened result.
- Attribute shiftValAttr;
+ TypedAttr shiftValAttr;
if (auto intTy = resultType.dyn_cast<IntegerType>()) {
unsigned resultBitwidth = intTy.getWidth();
- wideType = rewriter.getIntegerType(resultBitwidth * 2);
- shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth);
+ auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
+ shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
} else {
auto vecTy = resultType.cast<VectorType>();
unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
- wideType = VectorType::get(vecTy.getShape(),
- rewriter.getIntegerType(resultBitwidth * 2));
+ auto attrTy = VectorType::get(
+ vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
shiftValAttr = SplatElementsAttr::get(
- wideType, APInt(resultBitwidth * 2, resultBitwidth));
+ attrTy, APInt(resultBitwidth * 2, resultBitwidth));
}
+ Type wideType = shiftValAttr.getType();
assert(LLVM::isCompatibleType(wideType) &&
"LLVM dialect should support all signless integer types");
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index fb0cf4f38d79e..1790e3d0212c4 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -40,7 +40,7 @@ Type matchContainerType(Type element, Type container) {
return element;
}
-Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
+TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
if (auto shapedTy = type.dyn_cast<ShapedType>()) {
Type eTy = shapedTy.getElementType();
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b2e59b24979cb..3f970befa38dc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -625,7 +625,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
-static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
+static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
return rewriter.getFloatAttr(elementTy, 0.0);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 005ae10fb9fdb..61413b2a6d6ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -31,7 +31,7 @@ using namespace mlir;
using namespace mlir::tosa;
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
- Attribute padAttr, OpBuilder &rewriter) {
+ TypedAttr padAttr, OpBuilder &rewriter) {
// Input should be padded if necessary.
if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
return input;
@@ -224,7 +224,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto weightShape = weightTy.getShape();
// Apply padding as necessary.
- Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+ TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
int64_t iZp = quantizationInfo.getInputZp();
@@ -269,7 +269,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
- Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
+ auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultTy.getShape(), resultETy, filteredDims);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
@@ -391,7 +391,7 @@ class DepthwiseConvConverter
auto resultShape = resultTy.getShape();
// Apply padding as necessary.
- Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+ TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
@@ -439,7 +439,7 @@ class DepthwiseConvConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
- Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
+ auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, linalgConvTy.getShape(), resultETy, filteredDims);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
@@ -604,7 +604,7 @@ class FullyConnectedConverter
loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
// When quantized, the input elemeny type is not the same as the output
- Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
+ auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
@@ -688,7 +688,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
SmallVector<Value> dynamicDims = *dynamicDimsOr;
// Determine what the initial value needs to be for the max pool op.
- Attribute initialAttr;
+ TypedAttr initialAttr;
if (resultETy.isF32())
initialAttr = rewriter.getFloatAttr(
resultETy,
@@ -768,10 +768,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
pad.resize(2, 0);
llvm::append_range(pad, op.getPad());
pad.resize(pad.size() + 2, 0);
- Attribute padAttr = rewriter.getZeroAttr(inElementTy);
+ TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
- Attribute initialAttr = rewriter.getZeroAttr(accETy);
+ auto initialAttr = rewriter.getZeroAttr(accETy);
Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
ArrayRef<int64_t> kernel = op.getKernel();
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 1ef31df2defd7..5e46fab0a1ecd 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -296,7 +296,7 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
padConstant = rewriter.createOrFold<tensor::ExtractOp>(
loc, padOp.getPadConst(), ValueRange({}));
} else {
- Attribute constantAttr;
+ TypedAttr constantAttr;
if (elementTy.isa<FloatType>()) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
} else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 507b3872090eb..e454567c9213a 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1677,8 +1677,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
dynIdx++;
} else {
// Create ConstantOp for static dimension.
- Attribute constantAttr =
- b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
+ auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
inAffineApply.emplace_back(
b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7c687142247a6..d4c6b8184751f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -21,6 +21,8 @@ def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
// Subtract two integer attributes and createa a new one with the result.
def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
+class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
+
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
@@ -320,8 +322,8 @@ def TruncationMatchesShiftAmount :
// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr
- (Arith_ShRSIOp $x, (ConstantLikeMatcher AnyAttr:$c0))),
- (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp $c0))),
+ (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))),
+ (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d3ca1987a1707..446bb6461077d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2313,7 +2313,7 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
/// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc) {
switch (kind) {
case AtomicRMWKind::maxf:
@@ -2362,7 +2362,7 @@ Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
OpBuilder &builder, Location loc) {
- Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
+ auto attr = getIdentityValueAttr(op, resultType, builder, loc);
return builder.create<arith::ConstantOp>(loc, attr);
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 15e9b7f430806..22ec425b4730c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -62,7 +62,7 @@ static Type reduceInnermostDim(VectorType type) {
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
const APInt &value) {
- Attribute attr;
+ TypedAttr attr;
if (auto intTy = type.dyn_cast<IntegerType>()) {
attr = rewriter.getIntegerAttr(type, value);
} else {
@@ -989,7 +989,7 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
int64_t pow2Int = int64_t(1) << newBitWidth;
- Attribute pow2Attr =
+ TypedAttr pow2Attr =
rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
if (auto vecTy = dyn_cast<VectorType>(resultTy))
pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 75d92229c287b..1a54614cae33b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -121,7 +121,7 @@ LogicalResult emitc::ConstantOp::verify() {
if (getValueAttr().isa<emitc::OpaqueAttr>())
return success();
- TypedAttr value = getValueAttr();
+ auto value = cast<TypedAttr>(getValueAttr());
Type type = getType();
if (!value.getType().isa<NoneType>() && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
@@ -177,7 +177,7 @@ LogicalResult emitc::VariableOp::verify() {
if (getValueAttr().isa<emitc::OpaqueAttr>())
return success();
- TypedAttr value = getValueAttr();
+ auto value = cast<TypedAttr>(getValueAttr());
Type type = getType();
if (!value.getType().isa<NoneType>() && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index b07bdc91803ee..39c9e5e1725a8 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -492,9 +492,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
llvm::dbgs() << "\n");
// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
- auto comparator = [&](DeviceMappingAttrInterface a,
- DeviceMappingAttrInterface b) -> bool {
- return a.getMappingId() < b.getMappingId();
+ auto comparator = [&](Attribute a, Attribute b) -> bool {
+ return cast<DeviceMappingAttrInterface>(a).getMappingId() <
+ cast<DeviceMappingAttrInterface>(b).getMappingId();
};
SmallVector<int64_t> forallMappingSizes =
getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f9265b43eb379..4c8722c29022b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -389,10 +389,7 @@ class RegionBuilderHelper {
OpBuilder builder = getBuilder();
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
- Type type = NoneType::get(builder.getContext());
- if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
- type = typedAttr.getType();
- return builder.create<arith::ConstantOp>(loc, type, valueAttr);
+ return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
}
Value index(int64_t dim) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 10552ca3406c8..6f9b60843d6d5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -165,7 +165,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
staticStridesVector));
}
- Operation *clonedOp = clone(b, producer, resultTypes, clonedShapes);
+ LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
// Shift all IndexOp results by the tile offset.
SmallVector<OpFoldResult> allIvs = llvm::to_vector(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 8fbf8b247a99a..d8ecc807ea051 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -38,7 +38,7 @@ using namespace linalg;
static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
ArrayRef<int64_t> tiledLoopDims) {
// Get the consumer operand indexing map.
- LinalgOp consumerOp = consumerOperand->getOwner();
+ auto consumerOp = cast<LinalgOp>(consumerOperand->getOwner());
AffineMap indexingMap = consumerOp.getMatchingIndexingMap(consumerOperand);
// Search the slice dimensions tiled by a tile loop dimension.
@@ -65,7 +65,7 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
static SmallVector<int64_t>
getTiledProducerLoops(OpResult producerResult,
ArrayRef<int64_t> tiledSliceDimIndices) {
- LinalgOp producerOp = producerResult.getOwner();
+ auto producerOp = cast<LinalgOp>(producerResult.getOwner());
// Get the indexing map of the `producerOp` output operand that matches
// ´producerResult´.
@@ -137,7 +137,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
b.setInsertionPointAfter(sliceOp);
// Get the producer.
- LinalgOp producerOp = producerResult.getOwner();
+ auto producerOp = cast<LinalgOp>(producerResult.getOwner());
Location loc = producerOp.getLoc();
// Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
@@ -345,7 +345,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
return failure();
// Check `sliceOp` and `consumerOp` are in the same block.
- LinalgOp consumerOp = consumerOpOperand->getOwner();
+ auto consumerOp = cast<LinalgOp>(consumerOpOperand->getOwner());
if (sliceOp->getBlock() != rootOp->getBlock() ||
consumerOp->getBlock() != rootOp->getBlock())
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index adc1769bb6053..23c831f0a018a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -198,7 +198,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
"expected the number of loops and induction variables to match");
// Replace the index operations in the body of the innermost loop op.
if (!loopOps.empty()) {
- LoopLikeOpInterface loopOp = loopOps.back();
+ auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
for (IndexOp indexOp :
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index d0e865271a39a..b4d95b70de839 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -66,7 +66,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
Operation *reductionOp = combinerOps[0];
- std::optional<Attribute> identity = getNeutralElement(reductionOp);
+ std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
if (!identity.has_value())
return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
@@ -272,9 +272,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
return b.notifyMatchFailure(op, "cannot match a reduction pattern");
- SmallVector<Attribute> neutralElements;
+ SmallVector<TypedAttr> neutralElements;
for (Operation *reductionOp : combinerOps) {
- std::optional<Attribute> neutralElement = getNeutralElement(reductionOp);
+ std::optional<TypedAttr> neutralElement = getNeutralElement(reductionOp);
if (!neutralElement.has_value())
return b.notifyMatchFailure(op, "cannot find neutral element.");
neutralElements.push_back(*neutralElement);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c5026eed4423e..1c3745f66cbf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -271,7 +271,7 @@ struct LinalgOpPartialReductionInterface
return op->emitOpError("Failed to anaysis the reduction operation.");
Operation *reductionOp = combinerOps[0];
- std::optional<Attribute> identity = getNeutralElement(reductionOp);
+ std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
if (!identity.has_value())
return op->emitOpError(
"Failed to get an identity value for the reduction operation.");
@@ -328,8 +328,8 @@ struct LinalgOpPartialReductionInterface
// Step 1: Extract a slice of the input operands.
SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
- SmallVector<Value, 4> tiledOperands =
- makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true);
+ SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+ b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
// Step 2: Extract the accumulator operands
SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dbda4356c81f0..166f42637523f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -83,11 +83,8 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
}
Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
- Type paddingType = rewriter.getType<NoneType>();
- if (auto typedAttr = paddingAttr.dyn_cast<TypedAttr>())
- paddingType = typedAttr.getType();
Value paddingValue = rewriter.create<arith::ConstantOp>(
- opToPad.getLoc(), paddingType, paddingAttr);
+ opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
// Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
OpOperand *currOpOperand = opOperand;
@@ -576,7 +573,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
rewriter, loc, operand, innerPackSizes, innerPos,
/*outerDimsPerm=*/{});
// TODO: value of the padding attribute should be determined by consumers.
- Attribute zeroAttr =
+ auto zeroAttr =
rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
packOps.push_back(rewriter.create<tensor::PackOp>(
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index f4969baa10e39..5e3413accf7c3 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -985,7 +985,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
}
/// Return the identity numeric value associated to the give op.
-std::optional<Attribute> getNeutralElement(Operation *op) {
+std::optional<TypedAttr> getNeutralElement(Operation *op) {
// Builder only used as helper for attribute creation.
OpBuilder b(op->getContext());
Type resultType = op->getResult(0).getType();
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 8f0f8dab77e22..6d286a31290e6 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1245,7 +1245,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
floatTy = broadcast(floatTy, shape);
intTy = broadcast(intTy, shape);
- auto bconst = [&](Attribute attr) -> Value {
+ auto bconst = [&](TypedAttr attr) -> Value {
Value value = b.create<arith::ConstantOp>(attr);
return broadcast(b, value, shape);
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index de0766daeff30..38fb11348f285 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -111,9 +111,9 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
loc, rewriter.getIndexType(), size);
sizes[i] = size;
} else {
- sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
- size =
- rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
+ auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
+ size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
+ sizes[i] = sizeAttr;
}
strides[i] = stride;
if (i > 0)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 0d4bff8aa6e09..3cd4937e96f26 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -44,7 +44,7 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
if (auto intTy = type.dyn_cast<IntegerType>())
return IntegerAttr::get(intTy, sizedValue);
- return SplatElementsAttr::get(type, sizedValue);
+ return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}
Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index cbf591372f9af..3a488b311b95a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -213,7 +213,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
-mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
+mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
if (tp.isa<FloatType>())
return builder.getFloatAttr(tp, 1.0);
if (tp.isa<IndexType>())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 47c581c8d88ca..b6e6def4e5860 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -79,7 +79,7 @@ Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
/// returning a null attribute.
-Attribute getOneAttr(Builder &builder, Type tp);
+TypedAttr getOneAttr(Builder &builder, Type tp);
/// Generates the comparison `v != 0` where `v` is of numeric type.
/// For floating types, we use the "unordered" comparator (i.e., returns
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e214820c2f47f..c8cf85e695749 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1729,7 +1729,7 @@ class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
if (!splat)
return failure();
- Attribute newAttr = splat.getSplatValue<Attribute>();
+ TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
@@ -1767,9 +1767,9 @@ class ExtractOpNonSplatConstantFolder final
copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
int64_t elemBeginPosition =
linearize(completePositions, computeStrides(vecTy.getShape()));
- auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
+ auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
- Attribute newAttr;
+ TypedAttr newAttr;
if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
SmallVector<Attribute> elementValues(
denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index e318d4dc15915..3f26558237a2f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -191,7 +191,7 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const final {
- MaskableOpInterface maskableOp = maskOp.getMaskableOp();
+ auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
if (!sourceOp)
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6592930b5f66f..2b5706aaa7748 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -692,8 +692,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
- Attribute newAttr = DenseElementsAttr::get(
- warpOp.getResult(operandIndex).getType(), scalarAttr);
+ auto newAttr = DenseElementsAttr::get(
+ cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
Location loc = warpOp.getLoc();
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 1aee27560ea35..e806db7b7cdef 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -79,7 +79,7 @@ struct MaskCompressOpConversion
src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
op.getConstantSrcAttr());
} else {
- Attribute zeroAttr = rewriter.getZeroAttr(opType);
+ auto zeroAttr = rewriter.getZeroAttr(opType);
src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 9203943123470..7943655aa1b89 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -315,7 +315,7 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
return getArrayAttr(attrs);
}
-Attribute Builder::getZeroAttr(Type type) {
+TypedAttr Builder::getZeroAttr(Type type) {
if (type.isa<FloatType>())
return getFloatAttr(type, 0.0);
if (type.isa<IndexType>())
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 7e95ca137fdd1..0810d8965b385 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -539,7 +539,7 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
elementType.getContext());
// Wrap AffineMap into Attribute.
- Attribute layout = AffineMapAttr::get(map);
+ auto layout = AffineMapAttr::get(map);
// Drop default memory space value and replace it with empty attribute.
memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -559,7 +559,7 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
elementType.getContext());
// Wrap AffineMap into Attribute.
- Attribute layout = AffineMapAttr::get(map);
+ auto layout = AffineMapAttr::get(map);
// Drop default memory space value and replace it with empty attribute.
memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -577,7 +577,7 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
elementType.getContext());
// Wrap AffineMap into Attribute.
- Attribute layout = AffineMapAttr::get(map);
+ auto layout = AffineMapAttr::get(map);
// Convert deprecated integer-like memory space to Attribute.
Attribute memorySpace =
@@ -598,7 +598,7 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
elementType.getContext());
// Wrap AffineMap into Attribute.
- Attribute layout = AffineMapAttr::get(map);
+ auto layout = AffineMapAttr::get(map);
// Convert deprecated integer-like memory space to Attribute.
Attribute memorySpace =
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 8b0581d8be04c..4c3713fa9a75b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -535,7 +535,7 @@ std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
spirv::SpecConstantOp
spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
- Attribute defaultValue) {
+ TypedAttr defaultValue) {
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
defaultValue);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 9d41166b0fb7a..487b66769390e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -220,7 +220,7 @@ class Deserializer {
/// Creates a spirv::SpecConstantOp.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
- Attribute defaultValue);
+ TypedAttr defaultValue);
/// Processes the OpVariable instructions at current `offset` into `binary`.
/// It is expected that this method is used for variables that are to be
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 0c34ccb47257e..46126b0c5f400 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -222,7 +222,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion, Operation *call = nullptr) {
+ bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block.
if (src->empty())
@@ -328,7 +328,7 @@ static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion, Operation *call = nullptr) {
+ bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
// We expect the region to have at least one block.
if (src->empty())
return failure();
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index f00febcfa9435..50504988689b0 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -50,7 +50,7 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
auto *originalOp = info->originalProducer.getOperation();
auto *originalOpInLinalgOpsVector =
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
- *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+ *originalOpInLinalgOpsVector = info->fusedProducer;
// Don't mark for erasure in the tensor case, let DCE handle this.
changed = true;
}
More information about the Mlir-commits
mailing list