[Mlir-commits] [mlir] f8f4fc1 - [MLIR][Arith][NFC] Use the interface of 'getElementTypeOrSelf' to get the resType
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 29 19:57:47 PST 2022
Author: liqinweng
Date: 2022-12-30T11:54:06+08:00
New Revision: f8f4fc11d1b9256685a44725b55e431c57aaade3
URL: https://github.com/llvm/llvm-project/commit/f8f4fc11d1b9256685a44725b55e431c57aaade3
DIFF: https://github.com/llvm/llvm-project/commit/f8f4fc11d1b9256685a44725b55e431c57aaade3.diff
LOG: [MLIR][Arith][NFC] Use the interface of 'getElementTypeOrSelf' to get the resType
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D140608
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f6446ea5ae4c..e61169cb64cc 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1176,12 +1176,9 @@ OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
getInMutable().assign(lhs.getIn());
return getResult();
}
- Type resType = getType();
- unsigned bitWidth;
- if (auto shapedType = resType.dyn_cast<ShapedType>())
- bitWidth = shapedType.getElementTypeBitWidth();
- else
- bitWidth = resType.getIntOrFloatBitWidth();
+
+ Type resType = getElementTypeOrSelf(getType());
+ unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
return a.zext(bitWidth);
@@ -1205,12 +1202,9 @@ OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
getInMutable().assign(lhs.getIn());
return getResult();
}
- Type resType = getType();
- unsigned bitWidth;
- if (auto shapedType = resType.dyn_cast<ShapedType>())
- bitWidth = shapedType.getElementTypeBitWidth();
- else
- bitWidth = resType.getIntOrFloatBitWidth();
+
+ Type resType = getElementTypeOrSelf(getType());
+ unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
return a.sext(bitWidth);
@@ -1259,13 +1253,8 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
return getResult();
}
- Type resType = getType();
- unsigned bitWidth;
- if (auto shapedType = resType.dyn_cast<ShapedType>())
- bitWidth = shapedType.getElementTypeBitWidth();
- else
- bitWidth = resType.getIntOrFloatBitWidth();
-
+ Type resType = getElementTypeOrSelf(getType());
+ unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
return a.trunc(bitWidth);
@@ -1361,12 +1350,7 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
- Type resType = getType();
- Type resEleType;
- if (auto shapedType = resType.dyn_cast<ShapedType>())
- resEleType = shapedType.getElementType();
- else
- resEleType = resType;
+ Type resEleType = getElementTypeOrSelf(getType());
return constFoldCastOp<IntegerAttr, FloatAttr>(
operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
FloatType floatTy = resEleType.cast<FloatType>();
@@ -1387,12 +1371,7 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
- Type resType = getType();
- Type resEleType;
- if (auto shapedType = resType.dyn_cast<ShapedType>())
- resEleType = shapedType.getElementType();
- else
- resEleType = resType;
+ Type resEleType = getElementTypeOrSelf(getType());
return constFoldCastOp<IntegerAttr, FloatAttr>(
operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
FloatType floatTy = resEleType.cast<FloatType>();
@@ -1412,17 +1391,12 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
- Type resType = getType();
- Type resEleType;
- if (auto shapedType = resType.dyn_cast<ShapedType>())
- resEleType = shapedType.getElementType();
- else
- resEleType = resType;
+ Type resType = getElementTypeOrSelf(getType());
+ unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<FloatAttr, IntegerAttr>(
- operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
- IntegerType intTy = resEleType.cast<IntegerType>();
+ operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
bool ignored;
- APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
+ APSInt api(bitWidth, /*isUnsigned=*/true);
castStatus = APFloat::opInvalidOp !=
a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
return api;
@@ -1438,17 +1412,12 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
- Type resType = getType();
- Type resEleType;
- if (auto shapedType = resType.dyn_cast<ShapedType>())
- resEleType = shapedType.getElementType();
- else
- resEleType = resType;
+ Type resType = getElementTypeOrSelf(getType());
+ unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<FloatAttr, IntegerAttr>(
- operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
- IntegerType intTy = resEleType.cast<IntegerType>();
+ operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
bool ignored;
- APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
+ APSInt api(bitWidth, /*isUnsigned=*/false);
castStatus = APFloat::opInvalidOp !=
a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
return api;
More information about the Mlir-commits
mailing list