[Mlir-commits] [mlir] [Vector] Migrate away from PointerUnion::{is,get} (NFC) (PR #120382)

Kazu Hirata llvmlistbot at llvm.org
Wed Dec 18 00:33:03 PST 2024


https://github.com/kazutakahirata created https://github.com/llvm/llvm-project/pull/120382

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.


>From 8ffd5cc5e3d9c1b74f62765fdbf5dbe1c02aa188 Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Tue, 17 Dec 2024 09:09:24 -0800
Subject: [PATCH] [Vector] Migrate away from PointerUnion::{is,get} (NFC)

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp       | 10 +++++-----
 .../Transforms/VectorTransferOpTransforms.cpp  | 18 +++++++++---------
 .../Vector/Transforms/VectorTransforms.cpp     |  4 ++--
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad709813c6216a..491b5f44b722b1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -327,8 +327,8 @@ SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
   SmallVector<int64_t> ints;
   llvm::transform(
       foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
-        assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
-        return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
+        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
+        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
       });
   return ints;
 }
@@ -346,7 +346,7 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                               loc, cast<IntegerAttr>(attr).getInt())
                           .getResult();
 
-                    return foldResult.get<Value>();
+                    return cast<Value>(foldResult);
                   });
   return values;
 }
@@ -1353,8 +1353,8 @@ LogicalResult vector::ExtractOp::verify() {
     return emitOpError(
         "expected position attribute of rank no greater than vector rank");
   for (auto [idx, pos] : llvm::enumerate(position)) {
-    if (pos.is<Attribute>()) {
-      int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+    if (auto attr = dyn_cast<Attribute>(pos)) {
+      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
       if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index bd5f06a3b46d42..b0892d16969d29 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -251,7 +251,7 @@ static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
       continue;
     }
 
-    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
+    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
     if (value == 1)
       continue;
     reducedShape.push_back(value.getSExtValue());
@@ -570,8 +570,8 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
   collapsedOffset = affine::makeComposedFoldedAffineApply(
       rewriter, loc, collapsedExpr, collapsedVals);
 
-  if (collapsedOffset.is<Value>()) {
-    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
+  if (auto value = dyn_cast<Value>(collapsedOffset)) {
+    indicesAfterCollapsing.push_back(value);
   } else {
     indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
         loc, *getConstantIntValue(collapsedOffset)));
@@ -841,8 +841,8 @@ class RewriteScalarExtractElementOfTransferRead
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, sym0 + sym1,
           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
-      if (ofr.is<Value>()) {
-        newIndices[newIndices.size() - 1] = ofr.get<Value>();
+      if (auto value = dyn_cast<Value>(ofr)) {
+        newIndices[newIndices.size() - 1] = value;
       } else {
         newIndices[newIndices.size() - 1] =
             rewriter.create<arith::ConstantIndexOp>(loc,
@@ -879,14 +879,14 @@ class RewriteScalarExtractOfTransferRead
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
-      assert(pos.is<Attribute>() && "Unexpected non-constant index");
-      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
+      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
       int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
           rewriter, extractOp.getLoc(),
           rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
-      if (ofr.is<Value>()) {
-        newIndices[idx] = ofr.get<Value>();
+      if (auto value = dyn_cast<Value>(ofr)) {
+        newIndices[idx] = value;
       } else {
         newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
             extractOp.getLoc(), *getConstantIntValue(ofr));
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 20cd9cba6909a6..21ec718efd6a7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -598,9 +598,9 @@ struct BubbleDownVectorBitCastForExtract
 
     // Get the first element of the mixed position as integer.
     auto mixedPos = extractOp.getMixedPosition();
-    if (mixedPos.size() > 0 && !mixedPos[0].is<Attribute>())
+    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
       return failure();
-    uint64_t index = cast<IntegerAttr>(mixedPos[0].get<Attribute>()).getInt();
+    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
 
     // Get the single scalar (as a vector) in the source value that packs the
     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>



More information about the Mlir-commits mailing list