[Mlir-commits] [mlir] 6089d61 - [mlir] Prevent implicit downcasting to interfaces

Rahul Kayaith llvmlistbot at llvm.org
Thu Apr 20 13:32:04 PDT 2023


Author: Rahul Kayaith
Date: 2023-04-20T16:31:54-04:00
New Revision: 6089d612a580738df00c22a43e6f2c29bd216af9

URL: https://github.com/llvm/llvm-project/commit/6089d612a580738df00c22a43e6f2c29bd216af9
DIFF: https://github.com/llvm/llvm-project/commit/6089d612a580738df00c22a43e6f2c29bd216af9.diff

LOG: [mlir] Prevent implicit downcasting to interfaces

Currently, conversions to interfaces may happen implicitly (e.g.
`Attribute -> TypedAttr`), failing a runtime assert if the interface
isn't actually implemented. This change marks the `Interface(ValueT)`
constructor as explicit so that a cast is required.

Where it was straightforward, I adjusted code to not require casts;
otherwise, I just made the casts explicit.

Depends on D148491, D148492

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D148493

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/Arith.h
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/FunctionInterfaces.td
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
    mlir/lib/Dialect/EmitC/IR/EmitC.cpp
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
    mlir/lib/Transforms/Utils/InliningUtils.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 3e14e4d346753..f285262982816 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -121,7 +121,7 @@ bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
 
 /// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                OpBuilder &builder, Location loc);
 
 /// Returns the identity value associated with an AtomicRMWKind op.

diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 5633b4ccd7cf6..8007027a86514 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -64,7 +64,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     if (!elementResult)
       return {};
 
-    return DenseElementsAttr::get(resultType, *elementResult);
+    return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
   }
 
   if (operands[0].isa<ElementsAttr>() && operands[1].isa<ElementsAttr>()) {
@@ -86,7 +86,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
       elementResults.push_back(*elementResult);
     }
 
-    return DenseElementsAttr::get(resultType, elementResults);
+    return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
   }
   return {};
 }
@@ -233,7 +233,7 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
         calculate(op.getSplatValue<ElementValueT>(), castStatus);
     if (!castStatus)
       return {};
-    return DenseElementsAttr::get(resType, elementResult);
+    return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
   }
   if (operands[0].isa<ElementsAttr>()) {
     // Operand is ElementsAttr-derived; perform an element-wise fold by
@@ -250,7 +250,7 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
       elementResults.push_back(elt);
     }
 
-    return DenseElementsAttr::get(resType, elementResults);
+    return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
   }
   return {};
 }

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 0d8b89fdf9539..77376ff90cdb2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -132,7 +132,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
 /// Return the identity numeric value associated to the give op. Return
 /// std::nullopt if there is no known neutral element.
-std::optional<Attribute> getNeutralElement(Operation *op);
+std::optional<TypedAttr> getNeutralElement(Operation *op);
 
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f45781a978900..ae6c284eaf6f9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1618,7 +1618,7 @@ def ReductionDeclareOp : OpenMP_Op<"reduction.declare", [Symbol,
       if (getAtomicReductionRegion().empty())
         return {};
 
-      return getAtomicReductionRegion().front().getArgument(0).getType();
+      return cast<PointerLikeType>(getAtomicReductionRegion().front().getArgument(0).getType());
     }
   }];
   let hasRegionVerifier = 1;

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index c4f9fa8a6fe05..dab24bd930326 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/TypeRange.h"
 #include "mlir/Support/LLVM.h"
 
 // Pull in all enum type definitions and utility function declarations.
@@ -28,8 +29,6 @@
 namespace mlir {
 
 class OpBuilder;
-class TypeRange;
-class ValueRange;
 class RewriterBase;
 
 /// Tests whether the given maps describe a row major matmul. The test is
@@ -116,6 +115,11 @@ class StructuredGenerator {
 // Note: this is a true builder that notifies the OpBuilder listener.
 Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
                  ValueRange newOperands);
+template <typename OpT>
+OpT clone(OpBuilder &b, OpT op, TypeRange newResultTypes,
+          ValueRange newOperands) {
+  return cast<OpT>(clone(b, op.getOperation(), newResultTypes, newOperands));
+}
 
 // Clone the current operation with the operands but leave the regions empty.
 // Note: this is a true builder that notifies the OpBuilder listener.

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 7197b1364bbce..1b0f4bfb3f629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -116,7 +116,7 @@ class Builder {
   // Returns a 0-valued attribute of the given `type`. This function only
   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
-  Attribute getZeroAttr(Type type);
+  TypedAttr getZeroAttr(Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index d8506cbcdad10..496c197e47152 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -78,9 +78,9 @@ class DenseElementsAttr : public Attribute {
   using Attribute::Attribute;
 
   /// Allow implicit conversion to ElementsAttr.
-  operator ElementsAttr() const {
-    return *this ? cast<ElementsAttr>() : nullptr;
-  }
+  operator ElementsAttr() const { return cast_if_present<ElementsAttr>(*this); }
+  /// Allow implicit conversion to TypedAttr.
+  operator TypedAttr() const { return ElementsAttr(*this); }
 
   /// Type trait used to check if the given type T is a potentially valid C++
   /// floating point type that can be used to access the underlying element
@@ -842,9 +842,10 @@ class BoolAttr : public Attribute {
 
   static BoolAttr get(MLIRContext *context, bool value);
 
-  /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
-  /// avoid bringing in all of IntegerAttrs methods.
+  /// Enable conversion to IntegerAttr and its interfaces. This uses conversion
+  /// vs. inheritance to avoid bringing in all of IntegerAttrs methods.
   operator IntegerAttr() const { return IntegerAttr(impl); }
+  operator TypedAttr() const { return IntegerAttr(impl); }
 
   /// Return the boolean value of this attribute.
   bool getValue() const;

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index c30454a5268c1..17bbdcccaed29 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -275,7 +275,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
     ///    has less parameters we drop the extra attributes, if there are more
     ///    parameters they won't have any attributes.
     void setType(Type newType) {
-      function_interface_impl::setFunctionType(this->getOperation(), newType);
+      function_interface_impl::setFunctionType($_op, newType);
     }
 
     //===------------------------------------------------------------------===//
@@ -316,7 +316,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
       Type newType = $_op.getTypeWithArgsAndResults(
           argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
       function_interface_impl::insertFunctionArguments(
-          this->getOperation(), argIndices, argTypes, argAttrs, argLocs,
+          $_op, argIndices, argTypes, argAttrs, argLocs,
           originalNumArgs, newType);
     }
 
@@ -336,7 +336,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
       Type newType = $_op.getTypeWithArgsAndResults(
         /*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes);
       function_interface_impl::insertFunctionResults(
-          this->getOperation(), resultIndices, resultTypes, resultAttrs,
+          $_op, resultIndices, resultTypes, resultAttrs,
           originalNumResults, newType);
     }
 
@@ -351,7 +351,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
     void eraseArguments(const BitVector &argIndices) {
       Type newType = $_op.getTypeWithoutArgs(argIndices);
       function_interface_impl::eraseFunctionArguments(
-        this->getOperation(), argIndices, newType);
+        $_op, argIndices, newType);
     }
 
     /// Erase a single result at `resultIndex`.
@@ -365,7 +365,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
     void eraseResults(const BitVector &resultIndices) {
       Type newType = $_op.getTypeWithoutResults(resultIndices);
       function_interface_impl::eraseFunctionResults(
-          this->getOperation(), resultIndices, newType);
+          $_op, resultIndices, newType);
     }
 
     /// Return the type of this function with the specified arguments and
@@ -414,7 +414,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
 
     /// Return all of the attributes for the argument at 'index'.
     ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
-      return function_interface_impl::getArgAttrs(this->getOperation(), index);
+      return function_interface_impl::getArgAttrs($_op, index);
     }
 
     /// Return an ArrayAttr containing all argument attribute dictionaries of
@@ -464,11 +464,11 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
     }
     void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes) {
       assert(attributes.size() == $_op.getNumArguments());
-      function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+      function_interface_impl::setAllArgAttrDicts($_op, attributes);
     }
     void setAllArgAttrs(ArrayRef<Attribute> attributes) {
       assert(attributes.size() == $_op.getNumArguments());
-      function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+      function_interface_impl::setAllArgAttrDicts($_op, attributes);
     }
     void setAllArgAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumArguments());
@@ -503,7 +503,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
 
     /// Return all of the attributes for the result at 'index'.
     ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
-      return function_interface_impl::getResultAttrs(this->getOperation(), index);
+      return function_interface_impl::getResultAttrs($_op, index);
     }
 
     /// Return an ArrayAttr containing all result attribute dictionaries of this
@@ -554,12 +554,12 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
     void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes) {
       assert(attributes.size() == $_op.getNumResults());
       function_interface_impl::setAllResultAttrDicts(
-        this->getOperation(), attributes);
+        $_op, attributes);
     }
     void setAllResultAttrs(ArrayRef<Attribute> attributes) {
       assert(attributes.size() == $_op.getNumResults());
       function_interface_impl::setAllResultAttrDicts(
-        this->getOperation(), attributes);
+        $_op, attributes);
     }
     void setAllResultAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumResults());
@@ -589,7 +589,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
     /// attribute is returned.
     DictionaryAttr getArgAttrDict(unsigned index) {
       assert(index < $_op.getNumArguments() && "invalid argument number");
-      return function_interface_impl::getArgAttrDict(this->getOperation(), index);
+      return function_interface_impl::getArgAttrDict($_op, index);
     }
 
     /// Returns the dictionary attribute corresponding to the result at 'index'.
@@ -597,7 +597,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
     /// returned.
     DictionaryAttr getResultAttrDict(unsigned index) {
       assert(index < $_op.getNumResults() && "invalid result number");
-      return function_interface_impl::getResultAttrDict(this->getOperation(), index);
+      return function_interface_impl::getResultAttrDict($_op, index);
     }
   }];
 

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index 71c4003708356..20e2dfb453ef7 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -91,7 +91,7 @@ class Interface : public BaseType {
   };
 
   /// Construct an interface from an instance of the value type.
-  Interface(ValueT t = ValueT())
+  explicit Interface(ValueT t = ValueT())
       : BaseType(t),
         conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
     assert((!t || conceptImpl) &&

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b6317053e0a6d..4651c29997f8e 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -296,22 +296,22 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
   // on operands zero-extended to i(2*N) bits, and truncate the results back to
   // iN types.
   if (!resultType.isa<LLVM::LLVMArrayType>()) {
-    Type wideType;
     // Shift amount necessary to extract the high bits from widened result.
-    Attribute shiftValAttr;
+    TypedAttr shiftValAttr;
 
     if (auto intTy = resultType.dyn_cast<IntegerType>()) {
       unsigned resultBitwidth = intTy.getWidth();
-      wideType = rewriter.getIntegerType(resultBitwidth * 2);
-      shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth);
+      auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
+      shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
     } else {
       auto vecTy = resultType.cast<VectorType>();
       unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
-      wideType = VectorType::get(vecTy.getShape(),
-                                 rewriter.getIntegerType(resultBitwidth * 2));
+      auto attrTy = VectorType::get(
+          vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
       shiftValAttr = SplatElementsAttr::get(
-          wideType, APInt(resultBitwidth * 2, resultBitwidth));
+          attrTy, APInt(resultBitwidth * 2, resultBitwidth));
     }
+    Type wideType = shiftValAttr.getType();
     assert(LLVM::isCompatibleType(wideType) &&
            "LLVM dialect should support all signless integer types");
 

diff  --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index fb0cf4f38d79e..1790e3d0212c4 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -40,7 +40,7 @@ Type matchContainerType(Type element, Type container) {
   return element;
 }
 
-Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
+TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
   if (auto shapedTy = type.dyn_cast<ShapedType>()) {
     Type eTy = shapedTy.getElementType();
     APInt valueInt(eTy.getIntOrFloatBitWidth(), value);

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b2e59b24979cb..3f970befa38dc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -625,7 +625,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
 
 // Returns the constant initial value for a given reduction operation. The
 // attribute type varies depending on the element type required.
-static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
+static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                                PatternRewriter &rewriter) {
   if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
     return rewriter.getFloatAttr(elementTy, 0.0);

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 005ae10fb9fdb..61413b2a6d6ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -31,7 +31,7 @@ using namespace mlir;
 using namespace mlir::tosa;
 
 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
-                            Attribute padAttr, OpBuilder &rewriter) {
+                            TypedAttr padAttr, OpBuilder &rewriter) {
   // Input should be padded if necessary.
   if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
     return input;
@@ -224,7 +224,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto weightShape = weightTy.getShape();
 
     // Apply padding as necessary.
-    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
     if (isQuantized) {
       auto quantizationInfo = *op.getQuantizationInfo();
       int64_t iZp = quantizationInfo.getInputZp();
@@ -269,7 +269,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                 weightPermValue);
 
-    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, filteredDims);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
@@ -391,7 +391,7 @@ class DepthwiseConvConverter
     auto resultShape = resultTy.getShape();
 
     // Apply padding as necessary.
-    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
@@ -439,7 +439,7 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
-    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, linalgConvTy.getShape(), resultETy, filteredDims);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
@@ -604,7 +604,7 @@ class FullyConnectedConverter
         loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
 
     // When quantized, the input elemeny type is not the same as the output
-    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{zero},
@@ -688,7 +688,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     SmallVector<Value> dynamicDims = *dynamicDimsOr;
 
     // Determine what the initial value needs to be for the max pool op.
-    Attribute initialAttr;
+    TypedAttr initialAttr;
     if (resultETy.isF32())
       initialAttr = rewriter.getFloatAttr(
           resultETy,
@@ -768,10 +768,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
     pad.resize(2, 0);
     llvm::append_range(pad, op.getPad());
     pad.resize(pad.size() + 2, 0);
-    Attribute padAttr = rewriter.getZeroAttr(inElementTy);
+    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
     Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
 
-    Attribute initialAttr = rewriter.getZeroAttr(accETy);
+    auto initialAttr = rewriter.getZeroAttr(accETy);
     Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
 
     ArrayRef<int64_t> kernel = op.getKernel();

diff  --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 1ef31df2defd7..5e46fab0a1ecd 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -296,7 +296,7 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
       padConstant = rewriter.createOrFold<tensor::ExtractOp>(
           loc, padOp.getPadConst(), ValueRange({}));
     } else {
-      Attribute constantAttr;
+      TypedAttr constantAttr;
       if (elementTy.isa<FloatType>()) {
         constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
       } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 507b3872090eb..e454567c9213a 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1677,8 +1677,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
       dynIdx++;
     } else {
       // Create ConstantOp for static dimension.
-      Attribute constantAttr =
-          b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
+      auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
       inAffineApply.emplace_back(
           b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
     }

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7c687142247a6..d4c6b8184751f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -21,6 +21,8 @@ def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
 // Subtract two integer attributes and createa a new one with the result.
 def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 
+class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
+
 //===----------------------------------------------------------------------===//
 // AddIOp
 //===----------------------------------------------------------------------===//
@@ -320,8 +322,8 @@ def TruncationMatchesShiftAmount :
 // trunci(shrsi(x, c)) -> trunci(shrui(x, c))
 def TruncIShrSIToTrunciShrUI :
     Pat<(Arith_TruncIOp:$tr
-          (Arith_ShRSIOp $x, (ConstantLikeMatcher AnyAttr:$c0))),
-        (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp $c0))),
+          (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))),
+        (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
         [(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
 
 // trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d3ca1987a1707..446bb6461077d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2313,7 +2313,7 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 /// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                             OpBuilder &builder, Location loc) {
   switch (kind) {
   case AtomicRMWKind::maxf:
@@ -2362,7 +2362,7 @@ Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
 /// Returns the identity value associated with an AtomicRMWKind op.
 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                     OpBuilder &builder, Location loc) {
-  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
+  auto attr = getIdentityValueAttr(op, resultType, builder, loc);
   return builder.create<arith::ConstantOp>(loc, attr);
 }
 

diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 15e9b7f430806..22ec425b4730c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -62,7 +62,7 @@ static Type reduceInnermostDim(VectorType type) {
 static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
                                          Location loc, Type type,
                                          const APInt &value) {
-  Attribute attr;
+  TypedAttr attr;
   if (auto intTy = type.dyn_cast<IntegerType>()) {
     attr = rewriter.getIntegerAttr(type, value);
   } else {
@@ -989,7 +989,7 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
     Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
 
     int64_t pow2Int = int64_t(1) << newBitWidth;
-    Attribute pow2Attr =
+    TypedAttr pow2Attr =
         rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
     if (auto vecTy = dyn_cast<VectorType>(resultTy))
       pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);

diff  --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 75d92229c287b..1a54614cae33b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -121,7 +121,7 @@ LogicalResult emitc::ConstantOp::verify() {
   if (getValueAttr().isa<emitc::OpaqueAttr>())
     return success();
 
-  TypedAttr value = getValueAttr();
+  auto value = cast<TypedAttr>(getValueAttr());
   Type type = getType();
   if (!value.getType().isa<NoneType>() && type != value.getType())
     return emitOpError() << "requires attribute's type (" << value.getType()
@@ -177,7 +177,7 @@ LogicalResult emitc::VariableOp::verify() {
   if (getValueAttr().isa<emitc::OpaqueAttr>())
     return success();
 
-  TypedAttr value = getValueAttr();
+  auto value = cast<TypedAttr>(getValueAttr());
   Type type = getType();
   if (!value.getType().isa<NoneType>() && type != value.getType())
     return emitOpError() << "requires attribute's type (" << value.getType()

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index b07bdc91803ee..39c9e5e1725a8 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -492,9 +492,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
       llvm::dbgs() << "\n");
 
   // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
-  auto comparator = [&](DeviceMappingAttrInterface a,
-                        DeviceMappingAttrInterface b) -> bool {
-    return a.getMappingId() < b.getMappingId();
+  auto comparator = [&](Attribute a, Attribute b) -> bool {
+    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
+           cast<DeviceMappingAttrInterface>(b).getMappingId();
   };
   SmallVector<int64_t> forallMappingSizes =
       getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator);

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f9265b43eb379..4c8722c29022b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -389,10 +389,7 @@ class RegionBuilderHelper {
     OpBuilder builder = getBuilder();
     Location loc = builder.getUnknownLoc();
     Attribute valueAttr = parseAttribute(value, builder.getContext());
-    Type type = NoneType::get(builder.getContext());
-    if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
-      type = typedAttr.getType();
-    return builder.create<arith::ConstantOp>(loc, type, valueAttr);
+    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
   }
 
   Value index(int64_t dim) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 10552ca3406c8..6f9b60843d6d5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -165,7 +165,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
         staticStridesVector));
   }
 
-  Operation *clonedOp = clone(b, producer, resultTypes, clonedShapes);
+  LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
 
   // Shift all IndexOp results by the tile offset.
   SmallVector<OpFoldResult> allIvs = llvm::to_vector(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 8fbf8b247a99a..d8ecc807ea051 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -38,7 +38,7 @@ using namespace linalg;
 static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
                                               ArrayRef<int64_t> tiledLoopDims) {
   // Get the consumer operand indexing map.
-  LinalgOp consumerOp = consumerOperand->getOwner();
+  auto consumerOp = cast<LinalgOp>(consumerOperand->getOwner());
   AffineMap indexingMap = consumerOp.getMatchingIndexingMap(consumerOperand);
 
   // Search the slice dimensions tiled by a tile loop dimension.
@@ -65,7 +65,7 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
 static SmallVector<int64_t>
 getTiledProducerLoops(OpResult producerResult,
                       ArrayRef<int64_t> tiledSliceDimIndices) {
-  LinalgOp producerOp = producerResult.getOwner();
+  auto producerOp = cast<LinalgOp>(producerResult.getOwner());
 
   // Get the indexing map of the `producerOp` output operand that matches
   // ´producerResult´.
@@ -137,7 +137,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
   b.setInsertionPointAfter(sliceOp);
 
   // Get the producer.
-  LinalgOp producerOp = producerResult.getOwner();
+  auto producerOp = cast<LinalgOp>(producerResult.getOwner());
   Location loc = producerOp.getLoc();
 
   // Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
@@ -345,7 +345,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
     return failure();
 
   // Check `sliceOp` and `consumerOp` are in the same block.
-  LinalgOp consumerOp = consumerOpOperand->getOwner();
+  auto consumerOp = cast<LinalgOp>(consumerOpOperand->getOwner());
   if (sliceOp->getBlock() != rootOp->getBlock() ||
       consumerOp->getBlock() != rootOp->getBlock())
     return failure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index adc1769bb6053..23c831f0a018a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -198,7 +198,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
          "expected the number of loops and induction variables to match");
   // Replace the index operations in the body of the innermost loop op.
   if (!loopOps.empty()) {
-    LoopLikeOpInterface loopOp = loopOps.back();
+    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
     for (IndexOp indexOp :
          llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
       rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index d0e865271a39a..b4d95b70de839 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -66,7 +66,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
 
   Operation *reductionOp = combinerOps[0];
-  std::optional<Attribute> identity = getNeutralElement(reductionOp);
+  std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
   if (!identity.has_value())
     return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
 
@@ -272,9 +272,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
     return b.notifyMatchFailure(op, "cannot match a reduction pattern");
 
-  SmallVector<Attribute> neutralElements;
+  SmallVector<TypedAttr> neutralElements;
   for (Operation *reductionOp : combinerOps) {
-    std::optional<Attribute> neutralElement = getNeutralElement(reductionOp);
+    std::optional<TypedAttr> neutralElement = getNeutralElement(reductionOp);
     if (!neutralElement.has_value())
       return b.notifyMatchFailure(op, "cannot find neutral element.");
     neutralElements.push_back(*neutralElement);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c5026eed4423e..1c3745f66cbf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -271,7 +271,7 @@ struct LinalgOpPartialReductionInterface
       return op->emitOpError("Failed to anaysis the reduction operation.");
 
     Operation *reductionOp = combinerOps[0];
-    std::optional<Attribute> identity = getNeutralElement(reductionOp);
+    std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
     if (!identity.has_value())
       return op->emitOpError(
           "Failed to get an identity value for the reduction operation.");
@@ -328,8 +328,8 @@ struct LinalgOpPartialReductionInterface
 
     // Step 1: Extract a slice of the input operands.
     SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
-    SmallVector<Value, 4> tiledOperands =
-        makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true);
+    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
 
     // Step 2: Extract the accumulator operands
     SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dbda4356c81f0..166f42637523f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -83,11 +83,8 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
     return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
   }
   Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
-  Type paddingType = rewriter.getType<NoneType>();
-  if (auto typedAttr = paddingAttr.dyn_cast<TypedAttr>())
-    paddingType = typedAttr.getType();
   Value paddingValue = rewriter.create<arith::ConstantOp>(
-      opToPad.getLoc(), paddingType, paddingAttr);
+      opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
 
   // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
   OpOperand *currOpOperand = opOperand;
@@ -576,7 +573,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
           rewriter, loc, operand, innerPackSizes, innerPos,
           /*outerDimsPerm=*/{});
       // TODO: value of the padding attribute should be determined by consumers.
-      Attribute zeroAttr =
+      auto zeroAttr =
           rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
       Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
       packOps.push_back(rewriter.create<tensor::PackOp>(

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index f4969baa10e39..5e3413accf7c3 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -985,7 +985,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
 }
 
 /// Return the identity numeric value associated to the give op.
-std::optional<Attribute> getNeutralElement(Operation *op) {
+std::optional<TypedAttr> getNeutralElement(Operation *op) {
   // Builder only used as helper for attribute creation.
   OpBuilder b(op->getContext());
   Type resultType = op->getResult(0).getType();

diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 8f0f8dab77e22..6d286a31290e6 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1245,7 +1245,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
   floatTy = broadcast(floatTy, shape);
   intTy = broadcast(intTy, shape);
 
-  auto bconst = [&](Attribute attr) -> Value {
+  auto bconst = [&](TypedAttr attr) -> Value {
     Value value = b.create<arith::ConstantOp>(attr);
     return broadcast(b, value, shape);
   };

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index de0766daeff30..38fb11348f285 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -111,9 +111,9 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
               loc, rewriter.getIndexType(), size);
         sizes[i] = size;
       } else {
-        sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
-        size =
-            rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
+        auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
+        size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
+        sizes[i] = sizeAttr;
       }
       strides[i] = stride;
       if (i > 0)

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 0d4bff8aa6e09..3cd4937e96f26 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -44,7 +44,7 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   if (auto intTy = type.dyn_cast<IntegerType>())
     return IntegerAttr::get(intTy, sizedValue);
 
-  return SplatElementsAttr::get(type, sizedValue);
+  return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
 }
 
 Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index cbf591372f9af..3a488b311b95a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -213,7 +213,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
-mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
+mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
   if (tp.isa<FloatType>())
     return builder.getFloatAttr(tp, 1.0);
   if (tp.isa<IndexType>())

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 47c581c8d88ca..b6e6def4e5860 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -79,7 +79,7 @@ Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than
 /// returning a null attribute.
-Attribute getOneAttr(Builder &builder, Type tp);
+TypedAttr getOneAttr(Builder &builder, Type tp);
 
 /// Generates the comparison `v != 0` where `v` is of numeric type.
 /// For floating types, we use the "unordered" comparator (i.e., returns

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e214820c2f47f..c8cf85e695749 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1729,7 +1729,7 @@ class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
     auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
     if (!splat)
       return failure();
-    Attribute newAttr = splat.getSplatValue<Attribute>();
+    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
     if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
       newAttr = DenseElementsAttr::get(vecDstType, newAttr);
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
@@ -1767,9 +1767,9 @@ class ExtractOpNonSplatConstantFolder final
     copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
     int64_t elemBeginPosition =
         linearize(completePositions, computeStrides(vecTy.getShape()));
-    auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
+    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
 
-    Attribute newAttr;
+    TypedAttr newAttr;
     if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
       SmallVector<Attribute> elementValues(
           denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index e318d4dc15915..3f26558237a2f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -191,7 +191,7 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
 private:
   LogicalResult matchAndRewrite(MaskOp maskOp,
                                 PatternRewriter &rewriter) const final {
-    MaskableOpInterface maskableOp = maskOp.getMaskableOp();
+    auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
     SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
     if (!sourceOp)
       return failure();

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6592930b5f66f..2b5706aaa7748 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -692,8 +692,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     unsigned operandIndex = yieldOperand->getOperandNumber();
     Attribute scalarAttr = dense.getSplatValue<Attribute>();
-    Attribute newAttr = DenseElementsAttr::get(
-        warpOp.getResult(operandIndex).getType(), scalarAttr);
+    auto newAttr = DenseElementsAttr::get(
+        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
     Location loc = warpOp.getLoc();
     rewriter.setInsertionPointAfter(warpOp);
     Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 1aee27560ea35..e806db7b7cdef 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -79,7 +79,7 @@ struct MaskCompressOpConversion
       src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
                                                op.getConstantSrcAttr());
     } else {
-      Attribute zeroAttr = rewriter.getZeroAttr(opType);
+      auto zeroAttr = rewriter.getZeroAttr(opType);
       src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
     }
 

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 9203943123470..7943655aa1b89 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -315,7 +315,7 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
   return getArrayAttr(attrs);
 }
 
-Attribute Builder::getZeroAttr(Type type) {
+TypedAttr Builder::getZeroAttr(Type type) {
   if (type.isa<FloatType>())
     return getFloatAttr(type, 0.0);
   if (type.isa<IndexType>())

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 7e95ca137fdd1..0810d8965b385 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -539,7 +539,7 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                                             elementType.getContext());
 
   // Wrap AffineMap into Attribute.
-  Attribute layout = AffineMapAttr::get(map);
+  auto layout = AffineMapAttr::get(map);
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -559,7 +559,7 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                                             elementType.getContext());
 
   // Wrap AffineMap into Attribute.
-  Attribute layout = AffineMapAttr::get(map);
+  auto layout = AffineMapAttr::get(map);
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -577,7 +577,7 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                                             elementType.getContext());
 
   // Wrap AffineMap into Attribute.
-  Attribute layout = AffineMapAttr::get(map);
+  auto layout = AffineMapAttr::get(map);
 
   // Convert deprecated integer-like memory space to Attribute.
   Attribute memorySpace =
@@ -598,7 +598,7 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                                             elementType.getContext());
 
   // Wrap AffineMap into Attribute.
-  Attribute layout = AffineMapAttr::get(map);
+  auto layout = AffineMapAttr::get(map);
 
   // Convert deprecated integer-like memory space to Attribute.
   Attribute memorySpace =

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 8b0581d8be04c..4c3713fa9a75b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -535,7 +535,7 @@ std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
 
 spirv::SpecConstantOp
 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
-                                        Attribute defaultValue) {
+                                        TypedAttr defaultValue) {
   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
   auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
                                                     defaultValue);

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 9d41166b0fb7a..487b66769390e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -220,7 +220,7 @@ class Deserializer {
 
   /// Creates a spirv::SpecConstantOp.
   spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
-                                           Attribute defaultValue);
+                                           TypedAttr defaultValue);
 
   /// Processes the OpVariable instructions at current `offset` into `binary`.
   /// It is expected that this method is used for variables that are to be

diff  --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 0c34ccb47257e..46126b0c5f400 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -222,7 +222,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                  Block::iterator inlinePoint, IRMapping &mapper,
                  ValueRange resultsToReplace, TypeRange regionResultTypes,
                  std::optional<Location> inlineLoc,
-                 bool shouldCloneInlinedRegion, Operation *call = nullptr) {
+                 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
   assert(resultsToReplace.size() == regionResultTypes.size());
   // We expect the region to have at least one block.
   if (src->empty())
@@ -328,7 +328,7 @@ static LogicalResult
 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                  Block::iterator inlinePoint, ValueRange inlinedOperands,
                  ValueRange resultsToReplace, std::optional<Location> inlineLoc,
-                 bool shouldCloneInlinedRegion, Operation *call = nullptr) {
+                 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
   // We expect the region to have at least one block.
   if (src->empty())
     return failure();

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index f00febcfa9435..50504988689b0 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -50,7 +50,7 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
         auto *originalOp = info->originalProducer.getOperation();
         auto *originalOpInLinalgOpsVector =
             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+        *originalOpInLinalgOpsVector = info->fusedProducer;
         // Don't mark for erasure in the tensor case, let DCE handle this.
         changed = true;
       }


        


More information about the Mlir-commits mailing list