[Mlir-commits] [mlir] 605fc89 - [mlir][Arithmetic] Add common constant folder function for type cast ops.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 12 19:12:26 PDT 2022


Author: jacquesguan
Date: 2022-04-13T02:11:59Z
New Revision: 605fc89a613e0a2215de35b0705ebd09a8fa5e1d

URL: https://github.com/llvm/llvm-project/commit/605fc89a613e0a2215de35b0705ebd09a8fa5e1d
DIFF: https://github.com/llvm/llvm-project/commit/605fc89a613e0a2215de35b0705ebd09a8fa5e1d.diff

LOG: [mlir][Arithmetic] Add common constant folder function for type cast ops.

This revision replaces the current type cast constant folders with a new common type cast constant folder function template.
It covers all of the former folders and additionally supports folding constant splats and vectors.

Differential Revision: https://reviews.llvm.org/D123489

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 7ba43c92e7563..d503bb02403ae 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -108,6 +108,61 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
   return {};
 }
 
+/// Folds a single-operand cast op whose operand is a constant attribute.
+/// A scalar `AttrElementT` operand folds to a `TargetAttrElementT` of type
+/// `resType`; a splat or non-splat `ElementsAttr` operand folds element-wise
+/// into a `DenseElementsAttr` of type `resType`. `calculate(value, castStatus)`
+/// converts one element value; it sets `castStatus` to false when the element
+/// cannot be folded, in which case a null attribute is returned. A null or
+/// unsupported operand also yields a null attribute.
+template <
+    class AttrElementT, class TargetAttrElementT,
+    class ElementValueT = typename AttrElementT::ValueType,
+    class TargetElementValueT = typename TargetAttrElementT::ValueType,
+    class CalculationT =
+        function_ref<TargetElementValueT(ElementValueT, bool &)>>
+Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
+                          const CalculationT &calculate) {
+  assert(operands.size() == 1 && "Cast op takes one operand");
+  if (!operands[0])
+    return {};
+
+  if (auto op = operands[0].dyn_cast<AttrElementT>()) {
+    bool castStatus = true;
+    auto res = calculate(op.getValue(), castStatus);
+    if (!castStatus)
+      return {};
+    return TargetAttrElementT::get(resType, res);
+  }
+  if (auto op = operands[0].dyn_cast<SplatElementsAttr>()) {
+    // The operand is a splat so we can avoid expanding the values out and
+    // just fold based on the splat value.
+    bool castStatus = true;
+    auto elementResult =
+        calculate(op.getSplatValue<ElementValueT>(), castStatus);
+    if (!castStatus)
+      return {};
+    return DenseElementsAttr::get(resType, elementResult);
+  }
+  if (auto op = operands[0].dyn_cast<ElementsAttr>()) {
+    // Operand is ElementsAttr-derived; perform an element-wise fold by
+    // expanding the values, giving up as soon as one element fails to cast.
+    bool castStatus = true;
+    auto opIt = op.value_begin<ElementValueT>();
+    SmallVector<TargetElementValueT> elementResults;
+    elementResults.reserve(op.getNumElements());
+    for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
+      auto elt = calculate(*opIt, castStatus);
+      if (!castStatus)
+        return {};
+      elementResults.push_back(elt);
+    }
+
+    return DenseElementsAttr::get(resType, elementResults);
+  }
+  return {};
+}
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_COMMONFOLDERS_H

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 50ff5581c216c..1fa4b1b8032a2 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -875,16 +875,20 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
-    return IntegerAttr::get(
-        getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
-
   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-
-  return {};
+  // Shaped results are folded per element, so use their element bit width.
+  // The lambda never clears `castStatus`: any constant operand folds.
+  Type resType = getType();
+  auto shapedType = resType.dyn_cast<ShapedType>();
+  unsigned bitWidth = shapedType ? shapedType.getElementTypeBitWidth()
+                                 : resType.getIntOrFloatBitWidth();
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      operands, resType, [bitWidth](const APInt &a, bool &castStatus) {
+        return a.zext(bitWidth);
+      });
 }
 
 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -900,16 +904,20 @@ LogicalResult arith::ExtUIOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
-    return IntegerAttr::get(
-        getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
-
   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-
-  return {};
+  // Shaped results are folded per element, so use their element bit width.
+  // The lambda never clears `castStatus`: any constant operand folds.
+  Type resType = getType();
+  auto shapedType = resType.dyn_cast<ShapedType>();
+  unsigned bitWidth = shapedType ? shapedType.getElementTypeBitWidth()
+                                 : resType.getIntOrFloatBitWidth();
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      operands, resType, [bitWidth](const APInt &a, bool &castStatus) {
+        return a.sext(bitWidth);
+      });
 }
 
 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -954,15 +962,17 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
     return getResult();
   }
 
-  if (!operands[0])
-    return {};
+  // Shaped results are folded per element, so use their element bit width.
+  // The lambda never clears `castStatus`: any constant operand folds.
+  Type resType = getType();
+  auto shapedType = resType.dyn_cast<ShapedType>();
+  unsigned bitWidth = shapedType ? shapedType.getElementTypeBitWidth()
+                                 : resType.getIntOrFloatBitWidth();
 
-  if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
-    return IntegerAttr::get(
-        getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
-  }
-
-  return {};
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      operands, resType, [bitWidth](const APInt &a, bool &castStatus) {
+        return a.trunc(bitWidth);
+      });
 }
 
 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1048,15 +1058,21 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
-    const APInt &api = lhs.getValue();
-    FloatType floatTy = getType().cast<FloatType>();
-    APFloat apf(floatTy.getFloatSemantics(),
-                APInt::getZero(floatTy.getWidth()));
-    apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
-    return FloatAttr::get(floatTy, apf);
-  }
-  return {};
+  Type resType = getType();
+  // Fold element-wise: for shaped results the conversion target is the
+  // element type, otherwise it is the result type itself.
+  auto shapedType = resType.dyn_cast<ShapedType>();
+  Type resEleType = shapedType ? shapedType.getElementType() : resType;
+  return constFoldCastOp<IntegerAttr, FloatAttr>(
+      operands, resType, [&resEleType](const APInt &a, bool &castStatus) {
+        // The conversion status is ignored, so this cast always folds.
+        FloatType floatTy = resEleType.cast<FloatType>();
+        APFloat apf(floatTy.getFloatSemantics(),
+                    APInt::getZero(floatTy.getWidth()));
+        apf.convertFromAPInt(a, /*IsSigned=*/false,
+                             APFloat::rmNearestTiesToEven);
+        return apf;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1068,15 +1084,21 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
-    const APInt &api = lhs.getValue();
-    FloatType floatTy = getType().cast<FloatType>();
-    APFloat apf(floatTy.getFloatSemantics(),
-                APInt::getZero(floatTy.getWidth()));
-    apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
-    return FloatAttr::get(floatTy, apf);
-  }
-  return {};
+  Type resType = getType();
+  // Fold element-wise: for shaped results the conversion target is the
+  // element type, otherwise it is the result type itself.
+  auto shapedType = resType.dyn_cast<ShapedType>();
+  Type resEleType = shapedType ? shapedType.getElementType() : resType;
+  return constFoldCastOp<IntegerAttr, FloatAttr>(
+      operands, resType, [&resEleType](const APInt &a, bool &castStatus) {
+        // The conversion status is ignored, so this cast always folds.
+        FloatType floatTy = resEleType.cast<FloatType>();
+        APFloat apf(floatTy.getFloatSemantics(),
+                    APInt::getZero(floatTy.getWidth()));
+        apf.convertFromAPInt(a, /*IsSigned=*/true,
+                             APFloat::rmNearestTiesToEven);
+        return apf;
+      });
 }
 //===----------------------------------------------------------------------===//
 // FPToUIOp
@@ -1087,21 +1109,21 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
-    const APFloat &apf = lhs.getValue();
-    IntegerType intTy = getType().cast<IntegerType>();
-    bool ignored;
-    APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
-    if (APFloat::opInvalidOp ==
-        apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
-      // Undefined behavior invoked - the destination type can't represent
-      // the input constant.
-      return {};
-    }
-    return IntegerAttr::get(getType(), api);
-  }
-
-  return {};
+  Type resType = getType();
+  auto shapedType = resType.dyn_cast<ShapedType>();
+  unsigned bitWidth = (shapedType ? shapedType.getElementType() : resType)
+                          .cast<IntegerType>()
+                          .getWidth();
+  return constFoldCastOp<FloatAttr, IntegerAttr>(
+      operands, resType, [bitWidth](const APFloat &a, bool &castStatus) {
+        bool ignored;
+        APSInt api(bitWidth, /*isUnsigned=*/true);
+        // Undefined behavior invoked - the destination type can't represent
+        // the input constant, so report failure and leave the op unfolded.
+        castStatus = APFloat::opInvalidOp !=
+                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
+        return api;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1113,21 +1135,21 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
-    const APFloat &apf = lhs.getValue();
-    IntegerType intTy = getType().cast<IntegerType>();
-    bool ignored;
-    APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
-    if (APFloat::opInvalidOp ==
-        apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
-      // Undefined behavior invoked - the destination type can't represent
-      // the input constant.
-      return {};
-    }
-    return IntegerAttr::get(getType(), api);
-  }
-
-  return {};
+  Type resType = getType();
+  auto shapedType = resType.dyn_cast<ShapedType>();
+  unsigned bitWidth = (shapedType ? shapedType.getElementType() : resType)
+                          .cast<IntegerType>()
+                          .getWidth();
+  return constFoldCastOp<FloatAttr, IntegerAttr>(
+      operands, resType, [bitWidth](const APFloat &a, bool &castStatus) {
+        bool ignored;
+        APSInt api(bitWidth, /*isUnsigned=*/false);
+        // Undefined behavior invoked - the destination type can't represent
+        // the input constant, so report failure and leave the op unfolded.
+        castStatus = APFloat::opInvalidOp !=
+                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
+        return api;
+      });
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index b4c92d6089e9b..e20725b84d12a 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -282,6 +282,53 @@ func @signExtendConstant() -> i16 {
   return %ext : i16
 }
 
+// CHECK-LABEL: @signExtendConstantSplat
+//       CHECK:   %[[cres:.+]] = arith.constant dense<-2> : vector<4xi16>
+//       CHECK:   return %[[cres]]
+func @signExtendConstantSplat() -> vector<4xi16> {
+  %c-2 = arith.constant -2 : i8
+  %splat = vector.splat %c-2 : vector<4xi8>
+  %ext = arith.extsi %splat : vector<4xi8> to vector<4xi16>
+  return %ext : vector<4xi16>
+}
+
+// CHECK-LABEL: @signExtendConstantVector
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[1, -3, 5, -7]> : vector<4xi16>
+//       CHECK:   return %[[cres]]
+func @signExtendConstantVector() -> vector<4xi16> {
+  %vector = arith.constant dense<[1, -3, 5, -7]> : vector<4xi8>
+  %ext = arith.extsi %vector : vector<4xi8> to vector<4xi16>
+  return %ext : vector<4xi16>
+}
+
+// CHECK-LABEL: @unsignedExtendConstant
+//       CHECK:   %[[cres:.+]] = arith.constant 254 : i16
+//       CHECK:   return %[[cres]]
+func @unsignedExtendConstant() -> i16 {
+  %c-2 = arith.constant -2 : i8
+  %ext = arith.extui %c-2 : i8 to i16
+  return %ext : i16
+}
+
+// CHECK-LABEL: @unsignedExtendConstantSplat
+//       CHECK:   %[[cres:.+]] = arith.constant dense<254> : vector<4xi16>
+//       CHECK:   return %[[cres]]
+func @unsignedExtendConstantSplat() -> vector<4xi16> {
+  %c-2 = arith.constant -2 : i8
+  %splat = vector.splat %c-2 : vector<4xi8>
+  %ext = arith.extui %splat : vector<4xi8> to vector<4xi16>
+  return %ext : vector<4xi16>
+}
+
+// CHECK-LABEL: @unsignedExtendConstantVector
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[1, 253, 5, 249]> : vector<4xi16>
+//       CHECK:   return %[[cres]]
+func @unsignedExtendConstantVector() -> vector<4xi16> {
+  %vector = arith.constant dense<[1, -3, 5, -7]> : vector<4xi8>
+  %ext = arith.extui %vector : vector<4xi8> to vector<4xi16>
+  return %ext : vector<4xi16>
+}
+
 // CHECK-LABEL: @truncConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]
@@ -291,6 +338,25 @@ func @truncConstant(%arg0: i8) -> i16 {
   return %tr : i16
 }
 
+// CHECK-LABEL: @truncConstantSplat
+//       CHECK:   %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
+//       CHECK:   return %[[cres]]
+func @truncConstantSplat() -> vector<4xi8> {
+  %c-2 = arith.constant -2 : i16
+  %splat = vector.splat %c-2 : vector<4xi16>
+  %trunc = arith.trunci %splat : vector<4xi16> to vector<4xi8>
+  return %trunc : vector<4xi8>
+}
+
+// CHECK-LABEL: @truncConstantVector
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[1, 3, -5, 7]> : vector<4xi8>
+//       CHECK:   return %[[cres]]
+func @truncConstantVector() -> vector<4xi8> {
+  %vector = arith.constant dense<[257, 259, -5, 7]> : vector<4xi16>
+  %trunc = arith.trunci %vector : vector<4xi16> to vector<4xi8>
+  return %trunc : vector<4xi8>
+}
+
 // CHECK-LABEL: @truncTrunc
 //       CHECK:   %[[cres:.+]] = arith.trunci %arg0 : i64 to i8
 //       CHECK:   return %[[cres]]
@@ -921,6 +987,25 @@ func @constant_FPtoUI() -> i32 {
   return %res : i32
 }
 
+// CHECK-LABEL: @constant_FPtoUI_splat(
+func @constant_FPtoUI_splat() -> vector<4xi32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<2> : vector<4xi32>
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant 2.0 : f32
+  %splat = vector.splat %c0 : vector<4xf32>
+  %res = arith.fptoui %splat : vector<4xf32> to vector<4xi32>
+  return %res : vector<4xi32>
+}
+
+// CHECK-LABEL: @constant_FPtoUI_vector(
+func @constant_FPtoUI_vector() -> vector<4xi32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  // CHECK: return %[[C0]]
+  %vector = arith.constant dense<[1.0, 3.0, 5.0, 7.0]> : vector<4xf32>
+  %res = arith.fptoui %vector : vector<4xf32> to vector<4xi32>
+  return %res : vector<4xi32>
+}
+
 // -----
 // CHECK-LABEL: @invalid_constant_FPtoUI(
 func @invalid_constant_FPtoUI() -> i32 {
@@ -942,6 +1027,25 @@ func @constant_FPtoSI() -> i32 {
   return %res : i32
 }
 
+// CHECK-LABEL: @constant_FPtoSI_splat(
+func @constant_FPtoSI_splat() -> vector<4xi32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<-2> : vector<4xi32>
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant -2.0 : f32
+  %splat = vector.splat %c0 : vector<4xf32>
+  %res = arith.fptosi %splat : vector<4xf32> to vector<4xi32>
+  return %res : vector<4xi32>
+}
+
+// CHECK-LABEL: @constant_FPtoSI_vector(
+func @constant_FPtoSI_vector() -> vector<4xi32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<[-1, -3, -5, -7]> : vector<4xi32>
+  // CHECK: return %[[C0]]
+  %vector = arith.constant dense<[-1.0, -3.0, -5.0, -7.0]> : vector<4xf32>
+  %res = arith.fptosi %vector : vector<4xf32> to vector<4xi32>
+  return %res : vector<4xi32>
+}
+
 // -----
 // CHECK-LABEL: @invalid_constant_FPtoSI(
 func @invalid_constant_FPtoSI() -> i8 {
@@ -962,16 +1066,54 @@ func @constant_SItoFP() -> f32 {
   return %res : f32
 }
 
+// CHECK-LABEL: @constant_SItoFP_splat(
+func @constant_SItoFP_splat() -> vector<4xf32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<-2.000000e+00> : vector<4xf32>
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant -2 : i32
+  %splat = vector.splat %c0 : vector<4xi32>
+  %res = arith.sitofp %splat : vector<4xi32> to vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @constant_SItoFP_vector(
+func @constant_SItoFP_vector() -> vector<4xf32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, -3.000000e+00, 5.000000e+00, -7.000000e+00]> : vector<4xf32>
+  // CHECK: return %[[C0]]
+  %vector = arith.constant dense<[1, -3, 5, -7]> : vector<4xi32>
+  %res = arith.sitofp %vector : vector<4xi32> to vector<4xf32>
+  return %res : vector<4xf32>
+}
+
 // -----
 // CHECK-LABEL: @constant_UItoFP(
 func @constant_UItoFP() -> f32 {
   // CHECK: %[[C0:.+]] = arith.constant 2.000000e+00 : f32
   // CHECK: return %[[C0]]
   %c0 = arith.constant 2 : i32
-  %res = arith.sitofp %c0 : i32 to f32
+  %res = arith.uitofp %c0 : i32 to f32
   return %res : f32
 }
 
+// CHECK-LABEL: @constant_UItoFP_splat(
+func @constant_UItoFP_splat() -> vector<4xf32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<2.540000e+02> : vector<4xf32>
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant -2 : i8
+  %splat = vector.splat %c0 : vector<4xi8>
+  %res = arith.uitofp %splat : vector<4xi8> to vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @constant_UItoFP_vector(
+func @constant_UItoFP_vector() -> vector<4xf32> {
+  // CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, 2.530000e+02, 5.000000e+00, 2.490000e+02]> : vector<4xf32>
+  // CHECK: return %[[C0]]
+  %vector = arith.constant dense<[1, -3, 5, -7]> : vector<4xi8>
+  %res = arith.uitofp %vector : vector<4xi8> to vector<4xf32>
+  return %res : vector<4xf32>
+}
+
 // -----
 
 // Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll


        


More information about the Mlir-commits mailing list