[Mlir-commits] [mlir] fffd966 - [mlir][tosa] Added folders for tosa.greater
Rob Suderman
llvmlistbot at llvm.org
Mon Aug 29 13:24:40 PDT 2022
Author: Rob Suderman
Date: 2022-08-29T13:20:01-07:00
New Revision: fffd966f0b0f6aff879e2fd60cfc75336beb226b
URL: https://github.com/llvm/llvm-project/commit/fffd966f0b0f6aff879e2fd60cfc75336beb226b
DIFF: https://github.com/llvm/llvm-project/commit/fffd966f0b0f6aff879e2fd60cfc75336beb226b.diff
LOG: [mlir][tosa] Added folders for tosa.greater
Added a folder for tosa.greater that folds splat constant operands.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D132707
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 620ee39bd139a..c256451cc190f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1161,6 +1161,8 @@ def Tosa_GreaterOp : Tosa_Op<"greater", [
let results = (outs
I1Tensor:$output
);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 99477093b75ac..200d81260bbf5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -441,20 +441,25 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr BinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
- RankedTensorType ty) {
+ RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- if (ty.getElementType().isa<IntegerType>()) {
+ auto lETy = lhs.getType().cast<ShapedType>().getElementType();
+ auto rETy = rhs.getType().cast<ShapedType>().getElementType();
+ if (lETy != rETy)
+ return {};
+
+ if (lETy.isa<IntegerType>()) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
- APInt result = IntFolder()(l, r);
- return DenseElementsAttr::get(ty, result);
+ auto result = IntFolder()(l, r);
+ return DenseElementsAttr::get(returnTy, result);
}
- if (ty.getElementType().isa<FloatType>()) {
+ if (lETy.isa<FloatType>()) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
- APFloat result = FloatFolder()(l, r);
- return DenseElementsAttr::get(ty, result);
+ auto result = FloatFolder()(l, r);
+ return DenseElementsAttr::get(returnTy, result);
}
}
@@ -501,6 +506,37 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
lhsTy);
}
+namespace {
+/// Adapts a boolean comparison functor `Cmp` to the folder interface used by
+/// BinaryFolder: the result of `Cmp()(l, r)` is returned as a 1-bit APInt so
+/// it can be stored in an i1 splat. Member functions of a class template are
+/// only instantiated when called, so `Cmp` only needs to accept the operand
+/// kind (APInt or APFloat) that a given instantiation is actually used with.
+template <typename Cmp>
+struct ComparisonFold {
+  APInt operator()(const APInt &l, const APInt &r) const {
+    return APInt(1, Cmp()(l, r));
+  }
+
+  APInt operator()(const APFloat &l, const APFloat &r) const {
+    return APInt(1, Cmp()(l, r));
+  }
+};
+
+/// Integer folder for tosa.greater: computes the signed comparison `l > r`
+/// (APInt::sgt) as a 1-bit APInt. Operands are taken by const reference so
+/// wide APInts are not copied on every call.
+struct APIntFoldGreater {
+  APInt operator()(const APInt &l, const APInt &r) const {
+    return APInt(1, l.sgt(r));
+  }
+};
+} // namespace
+
+/// Folds tosa.greater when both operands are constant DenseElementsAttrs that
+/// BinaryFolder can evaluate (splats of a matching element type), producing a
+/// constant i1 splat. Integers are compared as signed (APIntFoldGreater);
+/// floats are compared with std::greater<APFloat>.
+OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  // BinaryFolder materializes the result with DenseElementsAttr::get, which
+  // needs a ranked, statically shaped type. Bail out for unranked results
+  // (dyn_cast yields null) and dynamic shapes instead of crashing.
+  if (!resultTy || !resultTy.hasStaticShape())
+    return {};
+
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
+}
+
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
if (getInput().getType() == getType())
return getInput();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index f392e4297be99..4c12dafcfe5d8 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,54 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
// -----
+// CHECK-LABEL: @fold_greater_splat_f32_true
+// splat 4.0 > splat 2.0 folds to a splat `true` constant.
+func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %rhs = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %greater = "tosa.greater"(%lhs, %rhs) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %greater : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_f32_false
+// splat 1.0 > splat 2.0 folds to a splat `false` constant.
+func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %rhs = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %greater = "tosa.greater"(%lhs, %rhs) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %greater : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32_false
+// splat -10 > splat 8 folds to `false`; a negative lhs checks that the
+// integer comparison is signed.
+func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %greater = "tosa.greater"(%lhs, %rhs) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %greater : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32_true
+// splat -10 > splat -12 folds to `true` (signed comparison of negatives).
+func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
+  %greater = "tosa.greater"(%lhs, %rhs) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %greater : tensor<10xi1>
+}
+
+// -----
+
// CHECK-LABEL: @slice_splat
func.func @slice_splat() -> tensor<1x1x1xi32> {
// CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}
More information about the Mlir-commits mailing list