[Mlir-commits] [mlir] 98eead8 - [mlir][Value] Add v.getDefiningOp<OpTy>()

Sean Silva llvmlistbot at llvm.org
Mon May 11 12:56:41 PDT 2020


Author: Sean Silva
Date: 2020-05-11T12:55:27-07:00
New Revision: 98eead81868c1ba017cc5d8dbea11285d2eadc4c

URL: https://github.com/llvm/llvm-project/commit/98eead81868c1ba017cc5d8dbea11285d2eadc4c
DIFF: https://github.com/llvm/llvm-project/commit/98eead81868c1ba017cc5d8dbea11285d2eadc4c.diff

LOG: [mlir][Value] Add v.getDefiningOp<OpTy>()

Summary:
This makes a common pattern of
`dyn_cast_or_null<OpTy>(v.getDefiningOp())` more concise.

Differential Revision: https://reviews.llvm.org/D79681

Added: 
    

Modified: 
    mlir/docs/Tutorials/Toy/Ch-3.md
    mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
    mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
    mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
    mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
    mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
    mlir/include/mlir/IR/Value.h
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
    mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
    mlir/lib/Dialect/Affine/EDSC/Builders.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Quant/IR/QuantOps.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/Toy/Ch-3.md b/mlir/docs/Tutorials/Toy/Ch-3.md
index cc31454fe533..d6a72b071647 100644
--- a/mlir/docs/Tutorials/Toy/Ch-3.md
+++ b/mlir/docs/Tutorials/Toy/Ch-3.md
@@ -91,8 +91,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
                   mlir::PatternRewriter &rewriter) const override {
     // Look through the input of the current transpose.
     mlir::Value transposeInput = op.getOperand();
-    TransposeOp transposeInputOp =
-        llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
 
     // Input defined by another transpose? If not, no match.
     if (!transposeInputOp)

diff  --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
index 8529ea0f24ee..6b789c8d27d1 100644
--- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
@@ -40,8 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
                   mlir::PatternRewriter &rewriter) const override {
     // Look through the input of the current transpose.
     mlir::Value transposeInput = op.getOperand();
-    TransposeOp transposeInputOp =
-        llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
 
     // Input defined by another transpose? If not, no match.
     if (!transposeInputOp)

diff  --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index 0dd38b2c31a4..c979f2d5fae3 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
                   mlir::PatternRewriter &rewriter) const override {
     // Look through the input of the current transpose.
     mlir::Value transposeInput = op.getOperand();
-    TransposeOp transposeInputOp =
-        llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
 
     // Input defined by another transpose? If not, no match.
     if (!transposeInputOp)

diff  --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
index 0dd38b2c31a4..c979f2d5fae3 100644
--- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
                   mlir::PatternRewriter &rewriter) const override {
     // Look through the input of the current transpose.
     mlir::Value transposeInput = op.getOperand();
-    TransposeOp transposeInputOp =
-        llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
 
     // Input defined by another transpose? If not, no match.
     if (!transposeInputOp)

diff  --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
index 0dd38b2c31a4..c979f2d5fae3 100644
--- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
                   mlir::PatternRewriter &rewriter) const override {
     // Look through the input of the current transpose.
     mlir::Value transposeInput = op.getOperand();
-    TransposeOp transposeInputOp =
-        llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
 
     // Input defined by another transpose? If not, no match.
     if (!transposeInputOp)

diff  --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index fafc3876db27..d48b989578cf 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -63,8 +63,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
                   mlir::PatternRewriter &rewriter) const override {
     // Look through the input of the current transpose.
     mlir::Value transposeInput = op.getOperand();
-    TransposeOp transposeInputOp =
-        llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
 
     // Input defined by another transpose? If not, no match.
     if (!transposeInputOp)

diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 78517309468d..74f504c25156 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -116,6 +116,13 @@ class Value {
   /// defines it.
   Operation *getDefiningOp() const;
 
+  /// If this value is the result of an operation of type OpTy, return the
+  /// operation that defines it; otherwise return a null OpTy.
+  template <typename OpTy>
+  OpTy getDefiningOp() const {
+    return llvm::dyn_cast_or_null<OpTy>(getDefiningOp());
+  }
+
   /// If this value is the result of an operation, use it as a location,
   /// otherwise return an unknown location.
   Location getLoc() const;

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 185be49930b7..5a395937101f 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -453,7 +453,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
       auto symbol = operands[i];
       assert(isValidSymbol(symbol));
       // Check if the symbol is a constant.
-      if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol.getDefiningOp()))
+      if (auto cOp = symbol.getDefiningOp<ConstantIndexOp>())
         dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
                                           cOp.getValue());
     }

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index b43cd6bd7be6..5c3f33d0a693 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -665,7 +665,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
   // Add top level symbol.
   addSymbolId(getNumSymbolIds(), id);
   // Check if the symbol is a constant.
-  if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id.getDefiningOp()))
+  if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
     setIdToConstant(id, constOp.getValue());
 }
 

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 5151569d8067..e6d7127762d5 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -64,7 +64,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
     assert(cst->containsId(value) && "value expected to be present");
     if (isValidSymbol(value)) {
       // Check if the symbol is a constant.
-      if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
+      if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
         cst->setIdToConstant(value, cOp.getValue());
     } else if (auto loop = getForInductionVarOwner(value)) {
       if (failed(cst->addAffineForOpDomain(loop)))

diff  --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
index b52c264d8bab..3821b4a2cf34 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -219,7 +219,7 @@ struct LoopToGpuConverter {
 
 // Return true if the value is obviously a constant "one".
 static bool isConstantOne(Value value) {
-  if (auto def = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
+  if (auto def = value.getDefiningOp<ConstantIndexOp>())
     return def.getValue() == 1;
   return false;
 }
@@ -505,11 +505,11 @@ struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
 /// `upperBound`.
 static Value deriveStaticUpperBound(Value upperBound,
                                     PatternRewriter &rewriter) {
-  if (auto op = dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp())) {
+  if (auto op = upperBound.getDefiningOp<ConstantIndexOp>()) {
     return op;
   }
 
-  if (auto minOp = dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
+  if (auto minOp = upperBound.getDefiningOp<AffineMinOp>()) {
     for (const AffineExpr &result : minOp.map().getResults()) {
       if (auto constExpr = result.dyn_cast<AffineConstantExpr>()) {
         return rewriter.create<ConstantIndexOp>(minOp.getLoc(),
@@ -518,7 +518,7 @@ static Value deriveStaticUpperBound(Value upperBound,
     }
   }
 
-  if (auto multiplyOp = dyn_cast_or_null<MulIOp>(upperBound.getDefiningOp())) {
+  if (auto multiplyOp = upperBound.getDefiningOp<MulIOp>()) {
     if (auto lhs = dyn_cast_or_null<ConstantIndexOp>(
             deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter)
                 .getDefiningOp()))
@@ -607,7 +607,7 @@ static LogicalResult processParallelLoop(
                                   launchIndependent](Value val) -> Value {
     if (launchIndependent(val))
       return val;
-    if (ConstantOp constOp = dyn_cast_or_null<ConstantOp>(val.getDefiningOp()))
+    if (ConstantOp constOp = val.getDefiningOp<ConstantOp>())
       return rewriter.create<ConstantOp>(constOp.getLoc(), constOp.getValue());
     return {};
   };

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 6d9974233a9f..7ee82f9e18bf 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -110,7 +110,7 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
 LogicalResult
 LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
                                        PatternRewriter &rewriter) const {
-  auto subViewOp = dyn_cast_or_null<SubViewOp>(loadOp.memref().getDefiningOp());
+  auto subViewOp = loadOp.memref().getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -131,8 +131,7 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
 LogicalResult
 StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
                                         PatternRewriter &rewriter) const {
-  auto subViewOp =
-      dyn_cast_or_null<SubViewOp>(storeOp.memref().getDefiningOp());
+  auto subViewOp = storeOp.memref().getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }

diff  --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
index 50e26574b7d5..98e6be955cba 100644
--- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
@@ -93,7 +93,7 @@ categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
                             unsigned &numSymbols) {
   AffineExpr d;
   Value resultVal = nullptr;
-  if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val.getDefiningOp())) {
+  if (auto constant = val.getDefiningOp<ConstantIndexOp>()) {
     d = getAffineConstantExpr(constant.getValue(), context);
   } else if (isValidSymbol(val) && !isValidDim(val)) {
     d = getAffineSymbolExpr(numSymbols++, context);

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c6d67723ecd1..16f4a3c6068e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -591,7 +591,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
     // 2. Compose AffineApplyOps and dispatch dims or symbols.
     for (unsigned i = 0, e = operands.size(); i < e; ++i) {
       auto t = operands[i];
-      auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp());
+      auto affineApply = t.getDefiningOp<AffineApplyOp>();
       if (affineApply) {
         // a. Compose affine.apply operations.
         LLVM_DEBUG(affineApply.getOperation()->print(
@@ -912,7 +912,7 @@ void AffineApplyOp::getCanonicalizationPatterns(
 static LogicalResult foldMemRefCast(Operation *op) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
-    auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+    auto cast = operand.get().getDefiningOp<MemRefCastOp>();
     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 7b58d3a5ca0d..fe669624f6cb 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -965,7 +965,7 @@ static Value vectorizeOperand(Value operand, Operation *op,
     return nullptr;
   }
   // 3. vectorize constant.
-  if (auto constant = dyn_cast_or_null<ConstantOp>(operand.getDefiningOp())) {
+  if (auto constant = operand.getDefiningOp<ConstantOp>()) {
     return vectorizeConstant(
         op, constant,
         VectorType::get(state->strategy->vectorSizes, operand.getType()));

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 83b65798ed9e..3a055d04b962 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -425,9 +425,8 @@ static LogicalResult verify(LandingpadOp op) {
     } else {
       // catch - global addresses only.
       // Bitcast ops should have global addresses as their args.
-      if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) {
-        if (auto addrOp =
-                dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp()))
+      if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
+        if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
           continue;
         return op.emitError("constant clauses expected")
                    .attachNote(bcOp.getLoc())
@@ -435,9 +434,9 @@ static LogicalResult verify(LandingpadOp op) {
                   "bitcast used in clauses for landingpad";
       }
       // NullOp and AddressOfOp allowed
-      if (dyn_cast_or_null<NullOp>(value.getDefiningOp()))
+      if (value.getDefiningOp<NullOp>())
         continue;
-      if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp()))
+      if (value.getDefiningOp<AddressOfOp>())
         continue;
       return op.emitError("clause #")
              << idx << " is not a known constant - null, addressof, bitcast";

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5803824a3162..fc2353e4087e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -52,7 +52,7 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
 static LogicalResult foldMemRefCast(Operation *op) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
-    auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+    auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
     if (castOp && canFoldIntoConsumerOp(castOp)) {
       operand.set(castOp.getOperand());
       folded = true;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index b85c586633cb..d541ed2a4f2d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -319,8 +319,8 @@ fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
 
     // Must be a subview or a slice to guarantee there are loops we can fuse
     // into.
-    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
-    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
+    auto subView = consumedView.getDefiningOp<SubViewOp>();
+    auto slice = consumedView.getDefiningOp<SliceOp>();
     if (!subView && !slice) {
       LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
       continue;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 33479717a645..03f8d9e3fd18 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -88,7 +88,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
 /// Otherwise return size.
 static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
                                                  Value size) {
-  auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
+  auto affineMinOp = size.getDefiningOp<AffineMinOp>();
   if (!affineMinOp)
     return size;
   int64_t minConst = std::numeric_limits<int64_t>::max();
@@ -112,7 +112,7 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
     alignment_attr =
         IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
   if (!dynamicBuffers)
-    if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
+    if (auto cst = size.getDefiningOp<ConstantIndexOp>())
       return std_alloc(
           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
           ValueRange{}, alignment_attr);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 1fdbcdcb94fa..462c2ef0c9ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -287,7 +287,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
       // accesses, unless we statically know the subview size divides the view
       // size evenly.
       int64_t viewSize = viewType.getDimSize(r);
-      auto sizeCst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp());
+      auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
       if (ShapedType::isDynamic(viewSize) || !sizeCst ||
           (viewSize % sizeCst.getValue()) != 0) {
         // Compute min(size, dim - offset) to avoid out-of-bounds accesses.

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index f67a9f7fbc22..b0dc1fa10679 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -36,7 +36,7 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
   // Matches x -> [scast -> scast] -> y, replacing the second scast with the
   // value of x if the casts invert each other.
-  auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
+  auto srcScastOp = arg().getDefiningOp<StorageCastOp>();
   if (!srcScastOp || srcScastOp.arg().getType() != getType())
     return OpFoldResult();
   return srcScastOp.arg();

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index d93e1b835529..591179455c94 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -55,7 +55,7 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
 }
 
 static LogicalResult verify(ForOp op) {
-  if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
+  if (auto cst = op.step().getDefiningOp<ConstantIndexOp>())
     if (cst.getValue() <= 0)
       return op.emitOpError("constant step operand must be positive");
 
@@ -403,7 +403,7 @@ static LogicalResult verify(ParallelOp op) {
 
   // Check whether all constant step values are positive.
   for (Value stepValue : stepValues)
-    if (auto cst = dyn_cast_or_null<ConstantIndexOp>(stepValue.getDefiningOp()))
+    if (auto cst = stepValue.getDefiningOp<ConstantIndexOp>())
       if (cst.getValue() <= 0)
         return op.emitOpError("constant step operand must be positive");
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp
index 3c3140c052ee..94dba40a6436 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp
@@ -29,7 +29,7 @@ static void specializeLoopForUnrolling(ParallelOp op) {
   SmallVector<int64_t, 2> constantIndices;
   constantIndices.reserve(op.upperBound().size());
   for (auto bound : op.upperBound()) {
-    auto minOp = dyn_cast_or_null<AffineMinOp>(bound.getDefiningOp());
+    auto minOp = bound.getDefiningOp<AffineMinOp>();
     if (!minOp)
       return;
     int64_t minConstant = std::numeric_limits<int64_t>::max();

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index cb7fd0e6e2ea..553be944ab30 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -209,7 +209,7 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
 static LogicalResult foldMemRefCast(Operation *op) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
-    auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+    auto cast = operand.get().getDefiningOp<MemRefCastOp>();
     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;
@@ -1696,7 +1696,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
 
 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
   // Fold IndexCast(IndexCast(x)) -> x
-  auto cast = dyn_cast_or_null<IndexCastOp>(getOperand().getDefiningOp());
+  auto cast = getOperand().getDefiningOp<IndexCastOp>();
   if (cast && cast.getOperand().getType() == getType())
     return cast.getOperand();
 
@@ -2617,8 +2617,7 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
   auto folds = [](Operation *op) {
     bool folded = false;
     for (OpOperand &operand : op->getOpOperands()) {
-      auto castOp =
-          dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+      auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
       if (castOp && canFoldIntoConsumerOp(castOp)) {
         operand.set(castOp.getOperand());
         folded = true;
@@ -2890,12 +2889,11 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
   LogicalResult matchAndRewrite(ViewOp viewOp,
                                 PatternRewriter &rewriter) const override {
     Value memrefOperand = viewOp.getOperand(0);
-    MemRefCastOp memrefCastOp =
-        dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
+    MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
     if (!memrefCastOp)
       return failure();
     Value allocOperand = memrefCastOp.getOperand();
-    AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
+    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
     if (!allocOp)
       return failure();
     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index f00f2843bd18..96f8597baa34 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1611,7 +1611,7 @@ class TransposeFolder final : public OpRewritePattern<TransposeOp> {
 
     // Return if the input of 'transposeOp' is not defined by another transpose.
     TransposeOp parentTransposeOp =
-        dyn_cast_or_null<TransposeOp>(transposeOp.vector().getDefiningOp());
+        transposeOp.vector().getDefiningOp<TransposeOp>();
     if (!parentTransposeOp)
       return failure();
 
@@ -1684,7 +1684,7 @@ OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
   // into:
   //    %t = vector.tuple .., %e_i, ..  // one less use
   //    %x = %e_i
-  if (auto tupleOp = dyn_cast_or_null<TupleOp>(getOperand().getDefiningOp()))
+  if (auto tupleOp = getOperand().getDefiningOp<TupleOp>())
     return tupleOp.getOperand(getIndex());
   return {};
 }

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 209ed696c45a..0d1966fcaea5 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -193,12 +193,9 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
-  auto lbCstOp =
-      dyn_cast_or_null<ConstantIndexOp>(forOp.lowerBound().getDefiningOp());
-  auto ubCstOp =
-      dyn_cast_or_null<ConstantIndexOp>(forOp.upperBound().getDefiningOp());
-  auto stepCstOp =
-      dyn_cast_or_null<ConstantIndexOp>(forOp.step().getDefiningOp());
+  auto lbCstOp = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
+  auto ubCstOp = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
+  auto stepCstOp = forOp.step().getDefiningOp<ConstantIndexOp>();
   if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.getValue() < 0 ||
       ubCstOp.getValue() < 0 || stepCstOp.getValue() < 0)
     return failure();
@@ -590,12 +587,9 @@ LogicalResult mlir::loopUnrollByFactor(scf::ForOp forOp,
   Value stepUnrolled;
   bool generateEpilogueLoop = true;
 
-  auto lbCstOp =
-      dyn_cast_or_null<ConstantIndexOp>(forOp.lowerBound().getDefiningOp());
-  auto ubCstOp =
-      dyn_cast_or_null<ConstantIndexOp>(forOp.upperBound().getDefiningOp());
-  auto stepCstOp =
-      dyn_cast_or_null<ConstantIndexOp>(forOp.step().getDefiningOp());
+  auto lbCstOp = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
+  auto ubCstOp = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
+  auto stepCstOp = forOp.step().getDefiningOp<ConstantIndexOp>();
   if (lbCstOp && ubCstOp && stepCstOp) {
     // Constant loop bounds computation.
     int64_t lbCst = lbCstOp.getValue();
@@ -1313,12 +1307,11 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
   bool isZeroBased = false;
-  if (auto ubCst =
-          dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp()))
+  if (auto ubCst = lowerBound.getDefiningOp<ConstantIndexOp>())
     isZeroBased = ubCst.getValue() == 0;
 
   bool isStepOne = false;
-  if (auto stepCst = dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp()))
+  if (auto stepCst = step.getDefiningOp<ConstantIndexOp>())
     isStepOne = stepCst.getValue() == 1;
 
   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)


        


More information about the Mlir-commits mailing list