[Mlir-commits] [mlir] f8f4fc1 - [MLIR][Arith][NFC] Use the interface of 'getElementTypeOrSelf' to get the resType

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 29 19:57:47 PST 2022


Author: liqinweng
Date: 2022-12-30T11:54:06+08:00
New Revision: f8f4fc11d1b9256685a44725b55e431c57aaade3

URL: https://github.com/llvm/llvm-project/commit/f8f4fc11d1b9256685a44725b55e431c57aaade3
DIFF: https://github.com/llvm/llvm-project/commit/f8f4fc11d1b9256685a44725b55e431c57aaade3.diff

LOG: [MLIR][Arith][NFC] Use 'getElementTypeOrSelf' to get the result element type

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D140608

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f6446ea5ae4c..e61169cb64cc 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1176,12 +1176,9 @@ OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
+
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.zext(bitWidth);
@@ -1205,12 +1202,9 @@ OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
+
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.sext(bitWidth);
@@ -1259,13 +1253,8 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
     return getResult();
   }
 
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
-
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.trunc(bitWidth);
@@ -1361,12 +1350,7 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
@@ -1387,12 +1371,7 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
@@ -1412,17 +1391,12 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
-        IntegerType intTy = resEleType.cast<IntegerType>();
+      operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
-        APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
+        APSInt api(bitWidth, /*isUnsigned=*/true);
         castStatus = APFloat::opInvalidOp !=
                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
         return api;
@@ -1438,17 +1412,12 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
-        IntegerType intTy = resEleType.cast<IntegerType>();
+      operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
-        APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
+        APSInt api(bitWidth, /*isUnsigned=*/false);
         castStatus = APFloat::opInvalidOp !=
                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
         return api;


        


More information about the Mlir-commits mailing list