[Mlir-commits] [mlir] 605fc89 - [mlir][Arithmetic] Add common constant folder function for type cast ops.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 12 19:12:26 PDT 2022
Author: jacquesguan
Date: 2022-04-13T02:11:59Z
New Revision: 605fc89a613e0a2215de35b0705ebd09a8fa5e1d
URL: https://github.com/llvm/llvm-project/commit/605fc89a613e0a2215de35b0705ebd09a8fa5e1d
DIFF: https://github.com/llvm/llvm-project/commit/605fc89a613e0a2215de35b0705ebd09a8fa5e1d.diff
LOG: [mlir][Arithmetic] Add common constant folder function for type cast ops.
This revision replaces the current per-op type-cast constant folders with a new, common type-cast constant-folder function template.
It covers all of the former folders and adds support for folding constant splats and constant vectors.
Differential Revision: https://reviews.llvm.org/D123489
Added:
Modified:
mlir/include/mlir/Dialect/CommonFolders.h
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 7ba43c92e7563..d503bb02403ae 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -108,6 +108,56 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
return {};
}
+template <
+ class AttrElementT, class TargetAttrElementT,
+ class ElementValueT = typename AttrElementT::ValueType,
+ class TargetElementValueT = typename TargetAttrElementT::ValueType,
+ class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
+Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
+ const CalculationT &calculate) {
+ assert(operands.size() == 1 && "Cast op takes one operand");
+ if (!operands[0])
+ return {};
+
+ if (operands[0].isa<AttrElementT>()) {
+ auto op = operands[0].cast<AttrElementT>();
+ bool castStatus = true;
+ auto res = calculate(op.getValue(), castStatus);
+ if (!castStatus)
+ return {};
+ return TargetAttrElementT::get(resType, res);
+ }
+ if (operands[0].isa<SplatElementsAttr>()) {
+ // The operand is a splat so we can avoid expanding the values out and
+ // just fold based on the splat value.
+ auto op = operands[0].cast<SplatElementsAttr>();
+ bool castStatus = true;
+ auto elementResult =
+ calculate(op.getSplatValue<ElementValueT>(), castStatus);
+ if (!castStatus)
+ return {};
+ return DenseElementsAttr::get(resType, elementResult);
+ }
+ if (operands[0].isa<ElementsAttr>()) {
+ // Operand is ElementsAttr-derived; perform an element-wise fold by
+ // expanding the value.
+ auto op = operands[0].cast<ElementsAttr>();
+ bool castStatus = true;
+ auto opIt = op.value_begin<ElementValueT>();
+ SmallVector<TargetElementValueT> elementResults;
+ elementResults.reserve(op.getNumElements());
+ for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
+ auto elt = calculate(*opIt, castStatus);
+ if (!castStatus)
+ return {};
+ elementResults.push_back(elt);
+ }
+
+ return DenseElementsAttr::get(resType, elementResults);
+ }
+ return {};
+}
+
} // namespace mlir
#endif // MLIR_DIALECT_COMMONFOLDERS_H
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 50ff5581c216c..1fa4b1b8032a2 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -875,16 +875,20 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
- if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
- return IntegerAttr::get(
- getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
-
if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
getInMutable().assign(lhs.getIn());
return getResult();
}
-
- return {};
+ Type resType = getType();
+ unsigned bitWidth;
+ if (auto shapedType = resType.dyn_cast<ShapedType>())
+ bitWidth = shapedType.getElementTypeBitWidth();
+ else
+ bitWidth = resType.getIntOrFloatBitWidth();
+ return constFoldCastOp<IntegerAttr, IntegerAttr>(
+ operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+ return a.zext(bitWidth);
+ });
}
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -900,16 +904,20 @@ LogicalResult arith::ExtUIOp::verify() {
//===----------------------------------------------------------------------===//
OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
- if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
- return IntegerAttr::get(
- getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
-
if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
getInMutable().assign(lhs.getIn());
return getResult();
}
-
- return {};
+ Type resType = getType();
+ unsigned bitWidth;
+ if (auto shapedType = resType.dyn_cast<ShapedType>())
+ bitWidth = shapedType.getElementTypeBitWidth();
+ else
+ bitWidth = resType.getIntOrFloatBitWidth();
+ return constFoldCastOp<IntegerAttr, IntegerAttr>(
+ operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+ return a.sext(bitWidth);
+ });
}
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -954,15 +962,17 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
return getResult();
}
- if (!operands[0])
- return {};
+ Type resType = getType();
+ unsigned bitWidth;
+ if (auto shapedType = resType.dyn_cast<ShapedType>())
+ bitWidth = shapedType.getElementTypeBitWidth();
+ else
+ bitWidth = resType.getIntOrFloatBitWidth();
- if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
- return IntegerAttr::get(
- getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
- }
-
- return {};
+ return constFoldCastOp<IntegerAttr, IntegerAttr>(
+ operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+ return a.trunc(bitWidth);
+ });
}
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1048,15 +1058,21 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
- if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
- const APInt &api = lhs.getValue();
- FloatType floatTy = getType().cast<FloatType>();
- APFloat apf(floatTy.getFloatSemantics(),
- APInt::getZero(floatTy.getWidth()));
- apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
- return FloatAttr::get(floatTy, apf);
- }
- return {};
+ Type resType = getType();
+ Type resEleType;
+ if (auto shapedType = resType.dyn_cast<ShapedType>())
+ resEleType = shapedType.getElementType();
+ else
+ resEleType = resType;
+ return constFoldCastOp<IntegerAttr, FloatAttr>(
+ operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
+ FloatType floatTy = resEleType.cast<FloatType>();
+ APFloat apf(floatTy.getFloatSemantics(),
+ APInt::getZero(floatTy.getWidth()));
+ apf.convertFromAPInt(a, /*IsSigned=*/false,
+ APFloat::rmNearestTiesToEven);
+ return apf;
+ });
}
//===----------------------------------------------------------------------===//
@@ -1068,15 +1084,21 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
- if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
- const APInt &api = lhs.getValue();
- FloatType floatTy = getType().cast<FloatType>();
- APFloat apf(floatTy.getFloatSemantics(),
- APInt::getZero(floatTy.getWidth()));
- apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
- return FloatAttr::get(floatTy, apf);
- }
- return {};
+ Type resType = getType();
+ Type resEleType;
+ if (auto shapedType = resType.dyn_cast<ShapedType>())
+ resEleType = shapedType.getElementType();
+ else
+ resEleType = resType;
+ return constFoldCastOp<IntegerAttr, FloatAttr>(
+ operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
+ FloatType floatTy = resEleType.cast<FloatType>();
+ APFloat apf(floatTy.getFloatSemantics(),
+ APInt::getZero(floatTy.getWidth()));
+ apf.convertFromAPInt(a, /*IsSigned=*/true,
+ APFloat::rmNearestTiesToEven);
+ return apf;
+ });
}
//===----------------------------------------------------------------------===//
// FPToUIOp
@@ -1087,21 +1109,21 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
- if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
- const APFloat &apf = lhs.getValue();
- IntegerType intTy = getType().cast<IntegerType>();
- bool ignored;
- APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
- if (APFloat::opInvalidOp ==
- apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
- // Undefined behavior invoked - the destination type can't represent
- // the input constant.
- return {};
- }
- return IntegerAttr::get(getType(), api);
- }
-
- return {};
+ Type resType = getType();
+ Type resEleType;
+ if (auto shapedType = resType.dyn_cast<ShapedType>())
+ resEleType = shapedType.getElementType();
+ else
+ resEleType = resType;
+ return constFoldCastOp<FloatAttr, IntegerAttr>(
+ operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
+ IntegerType intTy = resEleType.cast<IntegerType>();
+ bool ignored;
+ APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
+ castStatus = APFloat::opInvalidOp !=
+ a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
+ return api;
+ });
}
//===----------------------------------------------------------------------===//
@@ -1113,21 +1135,21 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
- if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
- const APFloat &apf = lhs.getValue();
- IntegerType intTy = getType().cast<IntegerType>();
- bool ignored;
- APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
- if (APFloat::opInvalidOp ==
- apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
- // Undefined behavior invoked - the destination type can't represent
- // the input constant.
- return {};
- }
- return IntegerAttr::get(getType(), api);
- }
-
- return {};
+ Type resType = getType();
+ Type resEleType;
+ if (auto shapedType = resType.dyn_cast<ShapedType>())
+ resEleType = shapedType.getElementType();
+ else
+ resEleType = resType;
+ return constFoldCastOp<FloatAttr, IntegerAttr>(
+ operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
+ IntegerType intTy = resEleType.cast<IntegerType>();
+ bool ignored;
+ APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
+ castStatus = APFloat::opInvalidOp !=
+ a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
+ return api;
+ });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index b4c92d6089e9b..e20725b84d12a 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -282,6 +282,53 @@ func @signExtendConstant() -> i16 {
return %ext : i16
}
+// CHECK-LABEL: @signExtendConstantSplat
+// CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi16>
+// CHECK: return %[[cres]]
+func @signExtendConstantSplat() -> vector<4xi16> {
+ %c-2 = arith.constant -2 : i8
+ %splat = vector.splat %c-2 : vector<4xi8>
+ %ext = arith.extsi %splat : vector<4xi8> to vector<4xi16>
+ return %ext : vector<4xi16>
+}
+
+// CHECK-LABEL: @signExtendConstantVector
+// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16>
+// CHECK: return %[[cres]]
+func @signExtendConstantVector() -> vector<4xi16> {
+ %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8>
+ %ext = arith.extsi %vector : vector<4xi8> to vector<4xi16>
+ return %ext : vector<4xi16>
+}
+
+// CHECK-LABEL: @unsignedExtendConstant
+// CHECK: %[[cres:.+]] = arith.constant 2 : i16
+// CHECK: return %[[cres]]
+func @unsignedExtendConstant() -> i16 {
+ %c2 = arith.constant 2 : i8
+ %ext = arith.extui %c2 : i8 to i16
+ return %ext : i16
+}
+
+// CHECK-LABEL: @unsignedExtendConstantSplat
+// CHECK: %[[cres:.+]] = arith.constant dense<2> : vector<4xi16>
+// CHECK: return %[[cres]]
+func @unsignedExtendConstantSplat() -> vector<4xi16> {
+ %c2 = arith.constant 2 : i8
+ %splat = vector.splat %c2 : vector<4xi8>
+ %ext = arith.extui %splat : vector<4xi8> to vector<4xi16>
+ return %ext : vector<4xi16>
+}
+
+// CHECK-LABEL: @unsignedExtendConstantVector
+// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16>
+// CHECK: return %[[cres]]
+func @unsignedExtendConstantVector() -> vector<4xi16> {
+ %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8>
+ %ext = arith.extui %vector : vector<4xi8> to vector<4xi16>
+ return %ext : vector<4xi16>
+}
+
// CHECK-LABEL: @truncConstant
// CHECK: %[[cres:.+]] = arith.constant -2 : i16
// CHECK: return %[[cres]]
@@ -291,6 +338,25 @@ func @truncConstant(%arg0: i8) -> i16 {
return %tr : i16
}
+// CHECK-LABEL: @truncConstantSplat
+// CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
+// CHECK: return %[[cres]]
+func @truncConstantSplat() -> vector<4xi8> {
+ %c-2 = arith.constant -2 : i16
+ %splat = vector.splat %c-2 : vector<4xi16>
+ %trunc = arith.trunci %splat : vector<4xi16> to vector<4xi8>
+ return %trunc : vector<4xi8>
+}
+
+// CHECK-LABEL: @truncConstantVector
+// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8>
+// CHECK: return %[[cres]]
+func @truncConstantVector() -> vector<4xi8> {
+ %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16>
+ %trunc = arith.trunci %vector : vector<4xi16> to vector<4xi8>
+ return %trunc : vector<4xi8>
+}
+
// CHECK-LABEL: @truncTrunc
// CHECK: %[[cres:.+]] = arith.trunci %arg0 : i64 to i8
// CHECK: return %[[cres]]
@@ -921,6 +987,25 @@ func @constant_FPtoUI() -> i32 {
return %res : i32
}
+// CHECK-LABEL: @constant_FPtoUI_splat(
+func @constant_FPtoUI_splat() -> vector<4xi32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<2> : vector<4xi32>
+ // CHECK: return %[[C0]]
+ %c0 = arith.constant 2.0 : f32
+ %splat = vector.splat %c0 : vector<4xf32>
+ %res = arith.fptoui %splat : vector<4xf32> to vector<4xi32>
+ return %res : vector<4xi32>
+}
+
+// CHECK-LABEL: @constant_FPtoUI_vector(
+func @constant_FPtoUI_vector() -> vector<4xi32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+ // CHECK: return %[[C0]]
+ %vector = arith.constant dense<[1.0, 3.0, 5.0, 7.0]> : vector<4xf32>
+ %res = arith.fptoui %vector : vector<4xf32> to vector<4xi32>
+ return %res : vector<4xi32>
+}
+
// -----
// CHECK-LABEL: @invalid_constant_FPtoUI(
func @invalid_constant_FPtoUI() -> i32 {
@@ -942,6 +1027,25 @@ func @constant_FPtoSI() -> i32 {
return %res : i32
}
+// CHECK-LABEL: @constant_FPtoSI_splat(
+func @constant_FPtoSI_splat() -> vector<4xi32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<-2> : vector<4xi32>
+ // CHECK: return %[[C0]]
+ %c0 = arith.constant -2.0 : f32
+ %splat = vector.splat %c0 : vector<4xf32>
+ %res = arith.fptosi %splat : vector<4xf32> to vector<4xi32>
+ return %res : vector<4xi32>
+}
+
+// CHECK-LABEL: @constant_FPtoSI_vector(
+func @constant_FPtoSI_vector() -> vector<4xi32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<[-1, -3, -5, -7]> : vector<4xi32>
+ // CHECK: return %[[C0]]
+ %vector = arith.constant dense<[-1.0, -3.0, -5.0, -7.0]> : vector<4xf32>
+ %res = arith.fptosi %vector : vector<4xf32> to vector<4xi32>
+ return %res : vector<4xi32>
+}
+
// -----
// CHECK-LABEL: @invalid_constant_FPtoSI(
func @invalid_constant_FPtoSI() -> i8 {
@@ -962,16 +1066,54 @@ func @constant_SItoFP() -> f32 {
return %res : f32
}
+// CHECK-LABEL: @constant_SItoFP_splat(
+func @constant_SItoFP_splat() -> vector<4xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+ // CHECK: return %[[C0]]
+ %c0 = arith.constant 2 : i32
+ %splat = vector.splat %c0 : vector<4xi32>
+ %res = arith.sitofp %splat : vector<4xi32> to vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @constant_SItoFP_vector(
+func @constant_SItoFP_vector() -> vector<4xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, 3.000000e+00, 5.000000e+00, 7.000000e+00]> : vector<4xf32>
+ // CHECK: return %[[C0]]
+ %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+ %res = arith.sitofp %vector : vector<4xi32> to vector<4xf32>
+ return %res : vector<4xf32>
+}
+
// -----
// CHECK-LABEL: @constant_UItoFP(
func @constant_UItoFP() -> f32 {
// CHECK: %[[C0:.+]] = arith.constant 2.000000e+00 : f32
// CHECK: return %[[C0]]
%c0 = arith.constant 2 : i32
- %res = arith.sitofp %c0 : i32 to f32
+ %res = arith.uitofp %c0 : i32 to f32
return %res : f32
}
+// CHECK-LABEL: @constant_UItoFP_splat(
+func @constant_UItoFP_splat() -> vector<4xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+ // CHECK: return %[[C0]]
+ %c0 = arith.constant 2 : i32
+ %splat = vector.splat %c0 : vector<4xi32>
+ %res = arith.uitofp %splat : vector<4xi32> to vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @constant_UItoFP_vector(
+func @constant_UItoFP_vector() -> vector<4xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, 3.000000e+00, 5.000000e+00, 7.000000e+00]> : vector<4xf32>
+ // CHECK: return %[[C0]]
+ %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+ %res = arith.uitofp %vector : vector<4xi32> to vector<4xf32>
+ return %res : vector<4xf32>
+}
+
// -----
// Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll
More information about the Mlir-commits
mailing list