[Mlir-commits] [mlir] 088f15e - [mlir][tosa] Add folder for tosa.cast
Rob Suderman
llvmlistbot at llvm.org
Mon Aug 29 17:35:32 PDT 2022
Author: Rob Suderman
Date: 2022-08-29T17:21:24-07:00
New Revision: 088f15e346d68da2875d8ee618a05217559c25f2
URL: https://github.com/llvm/llvm-project/commit/088f15e346d68da2875d8ee618a05217559c25f2
DIFF: https://github.com/llvm/llvm-project/commit/088f15e346d68da2875d8ee618a05217559c25f2.diff
LOG: [mlir][tosa] Add folder for tosa.cast
Tosa.cast should fold on splats as it is trivial to fold the operation
into the splatted value.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D132518
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6d27d1e7f404..fe6371170225 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -23,6 +23,7 @@
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -687,6 +688,63 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
if (getInput().getType() == getType())
return getInput();
+
+  // Fold casts of splat constants: convert the single splatted scalar and
+  // re-splat it into the result type. Non-splat constants and element-type
+  // pairs not handled below are left unfolded.
+  auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
+  if (!operand || !operand.isSplat())
+    return {};
+
+  auto inTy = getInput().getType().cast<ShapedType>();
+  auto outTy = getType().cast<ShapedType>();
+  auto inETy = inTy.getElementType();
+  auto outETy = outTy.getElementType();
+
+  // TOSA CAST treats i1 as bool: bool -> T yields 0/1 (so i1 must be widened
+  // as unsigned, never sign-extended), and T -> bool yields (in != 0) rather
+  // than the low bit of `in`.
+  bool inIsBool = inETy.isInteger(1);
+  bool outIsBool = outETy.isInteger(1);
+
+  if (inETy.isa<FloatType>() && outETy.isa<FloatType>()) {
+    bool losesInfo;
+    auto splatVal = operand.getSplatValue<APFloat>();
+    auto &semantics = outETy.cast<FloatType>().getFloatSemantics();
+    splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
+                     &losesInfo);
+    return SplatElementsAttr::get(outTy, splatVal);
+  }
+
+  if (inETy.isa<IntegerType>() && outETy.isa<FloatType>()) {
+    // Signless integers convert as signed. Unsigned and bool inputs convert
+    // as unsigned so that `true` becomes 1.0 rather than -1.0.
+    bool isSigned =
+        !inIsBool && !inETy.cast<IntegerType>().isUnsignedInteger();
+    APFloat splatVal(outETy.cast<FloatType>().getFloatSemantics());
+    splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), isSigned,
+                              llvm::RoundingMode::NearestTiesToEven);
+    return SplatElementsAttr::get(outTy, splatVal);
+  }
+
+  if (inETy.isa<FloatType>() && outETy.isa<IntegerType>()) {
+    auto floatVal = operand.getSplatValue<APFloat>();
+    if (outIsBool) {
+      // float -> bool is (in != 0); NaN is not zero, so it maps to true.
+      APInt boolVal(1, floatVal.isZero() ? 0 : 1);
+      return SplatElementsAttr::get(outTy, boolVal);
+    }
+    // TOSA rounds float -> int to the nearest integer (not toward zero).
+    // APFloat::convertToInteger saturates out-of-range values to the
+    // destination min/max and maps NaN to 0.
+    auto unsign = outETy.cast<IntegerType>().isUnsignedInteger();
+    APSInt intVal(outETy.getIntOrFloatBitWidth(), unsign);
+    bool exact;
+    floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
+                              &exact);
+    return SplatElementsAttr::get(outTy, intVal);
+  }
+
+  if (inETy.isa<IntegerType>() && outETy.isa<IntegerType>()) {
+    auto intVal = operand.getSplatValue<APInt>();
+    if (outIsBool) {
+      // int -> bool is (in != 0), not truncation to the low bit.
+      APInt boolVal(1, intVal.isZero() ? 0 : 1);
+      return SplatElementsAttr::get(outTy, boolVal);
+    }
+
+    auto bitwidth = outETy.getIntOrFloatBitWidth();
+    if (inETy.getIntOrFloatBitWidth() > bitwidth) {
+      // Narrowing keeps the low bits.
+      intVal = intVal.trunc(bitwidth);
+    } else if (inIsBool || inETy.cast<IntegerType>().isUnsignedInteger()) {
+      // Unsigned inputs, and bool (which must become 0/1), zero-extend.
+      intVal = intVal.zext(bitwidth);
+    } else {
+      // Signless/signed inputs sign-extend.
+      intVal = intVal.sext(bitwidth);
+    }
+    return SplatElementsAttr::get(outTy, intVal);
+  }
+
return {};
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index af24d590bfaf..43f2196c834d 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -427,3 +427,58 @@ func.func @slice_singleton() -> tensor<1x1xi32> {
// CHECK: return %[[SLICE]]
return %slice : tensor<1x1xi32>
}
+
+// -----
+
+// A splat f32 constant cast to f16 becomes an f16 constant, which the
+// function then returns directly.
+// CHECK: func.func @cast_float_to_float
+func.func @cast_float_to_float() -> tensor<f16> {
+  %cst = "tosa.const"() {value = dense<42.0> : tensor<f32>} : () -> tensor<f32>
+  %res = "tosa.cast"(%cst) : (tensor<f32>) -> tensor<f16>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<4.200000e+01> : tensor<f16>} : () -> tensor<f16>
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<f16>
+}
+
+// -----
+
+// A splat i32 constant cast to f16 becomes an f16 constant holding the
+// converted value, which the function then returns directly.
+// CHECK: func.func @cast_int_to_float
+func.func @cast_int_to_float() -> tensor<f16> {
+  %cst = "tosa.const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
+  %res = "tosa.cast"(%cst) : (tensor<i32>) -> tensor<f16>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<4.000000e+00> : tensor<f16>} : () -> tensor<f16>
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<f16>
+}
+
+// -----
+
+// A splat f32 constant holding an exact integer (-4.0) cast to i16 becomes
+// an i16 constant, which the function then returns directly.
+// CHECK: func.func @cast_float_to_int
+func.func @cast_float_to_int() -> tensor<i16> {
+  %cst = "tosa.const"() {value = dense<-4.0> : tensor<f32>} : () -> tensor<f32>
+  %res = "tosa.cast"(%cst) : (tensor<f32>) -> tensor<i16>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<-4> : tensor<i16>} : () -> tensor<i16>
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<i16>
+}
+
+// -----
+
+// Narrowing i32 -> i16 keeps the low bits, so an all-ones -1 stays -1.
+// CHECK: func.func @cast_int_to_int_trunc
+func.func @cast_int_to_int_trunc() -> tensor<i16> {
+  %cst = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %res = "tosa.cast"(%cst) : (tensor<i32>) -> tensor<i16>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<i16>
+}
+
+// -----
+
+// Widening a signless i16 to i32 sign-extends, so -1 stays -1.
+// CHECK: func.func @cast_int_to_int_sign
+func.func @cast_int_to_int_sign() -> tensor<i32> {
+  %cst = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+  %res = "tosa.cast"(%cst) : (tensor<i16>) -> tensor<i32>
+  // CHECK: %[[FOLDED:.+]] = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: return %[[FOLDED]]
+  return %res : tensor<i32>
+}
More information about the Mlir-commits mailing list