[Mlir-commits] [mlir] e1bb203 - [mlir][tosa] Added folders for tosa.mul
Rob Suderman
llvmlistbot at llvm.org
Mon Aug 29 16:45:34 PDT 2022
Author: Rob Suderman
Date: 2022-08-29T16:43:50-07:00
New Revision: e1bb203755bc8f797ca4e3b2160732f9ef9b356a
URL: https://github.com/llvm/llvm-project/commit/e1bb203755bc8f797ca4e3b2160732f9ef9b356a
DIFF: https://github.com/llvm/llvm-project/commit/e1bb203755bc8f797ca4e3b2160732f9ef9b356a.diff
LOG: [mlir][tosa] Added folders for tosa.mul
Added folders for tosa.mul that bypass multiplication by a splat zero or
one, and fold the multiplication of two splat tensors.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D132678
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ff8c47c3b1a4..29b40cdbcaaa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -737,6 +737,7 @@ def Tosa_MulOp : Tosa_Op<"mul", [
);
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39c74974ab1f..6e9765571e85 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -506,6 +506,89 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
lhsTy);
}
+namespace {
+/// Constant-folds `tosa.mul` when both operands are splat constants.
+///
+/// \p ty is the type of the folded constant and \p shift is the op's `shift`
+/// attribute. Integer products with a positive shift are computed at more
+/// than twice the element width, rounded to nearest and shifted right
+/// arithmetically:
+///
+///   result = (lhs * rhs + (1 << (shift - 1))) >> shift
+///
+/// which is the rounding rule of the TOSA MUL operator. Returns a null
+/// attribute when the operands are not both splats, when their element type
+/// differs from the element type of \p ty, when \p ty is not statically
+/// shaped, when the shift is out of range, or when the shifted product does
+/// not fit the element type.
+DenseElementsAttr MulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
+                                  RankedTensorType ty, int32_t shift) {
+  if (!lhs || !rhs || !lhs.isSplat() || !rhs.isSplat())
+    return {};
+  // DenseElementsAttr::get requires a statically shaped type.
+  if (!ty.hasStaticShape())
+    return {};
+
+  Type eTy = ty.getElementType();
+  // Mixed-type products (e.g. i8 x i8 -> i32) would need operand widening;
+  // they are not folded here.
+  if (lhs.getElementType() != eTy || rhs.getElementType() != eTy)
+    return {};
+
+  if (eTy.isa<IntegerType>()) {
+    APInt l = lhs.getSplatValue<APInt>();
+    APInt r = rhs.getSplatValue<APInt>();
+
+    // Without a shift the product simply wraps to the element width.
+    if (shift == 0)
+      return DenseElementsAttr::get(ty, l * r);
+
+    unsigned bitwidth = eTy.getIntOrFloatBitWidth();
+    // Two element widths hold any product exactly; one extra bit guarantees
+    // that adding the rounding term below cannot overflow.
+    unsigned wideWidth = bitwidth * 2 + 1;
+    // A negative shift, or one at least as wide as the widened product, has
+    // no meaningful result and would trip APInt's shift-amount assertions.
+    if (shift < 0 || static_cast<unsigned>(shift) >= wideWidth)
+      return {};
+
+    APInt product = l.sext(wideWidth) * r.sext(wideWidth);
+    // Round to nearest, then shift while preserving the sign.
+    product +=
+        APInt::getOneBitSet(wideWidth, static_cast<unsigned>(shift - 1));
+    product.ashrInPlace(static_cast<unsigned>(shift));
+    // TOSA requires the shifted product to be representable; rather than
+    // inventing a value by truncation, leave such ops unfolded.
+    if (!product.isSignedIntN(bitwidth))
+      return {};
+    return DenseElementsAttr::get(ty, product.trunc(bitwidth));
+  }
+
+  if (eTy.isa<FloatType>()) {
+    APFloat result =
+        lhs.getSplatValue<APFloat>() * rhs.getSplatValue<APFloat>();
+    return DenseElementsAttr::get(ty, result);
+  }
+
+  return {};
+}
+} // namespace
+
+/// Folds `tosa.mul` when one operand is a splat zero (the zero constant is
+/// returned) or a splat "one" (the other operand is returned), or when both
+/// operands are splat constants (delegated to MulBinaryFolder).
+///
+/// Only the non-broadcasting case in which both operands and the result share
+/// one type is handled. Anything else, including mixed-type products such as
+/// i8 x i8 -> i32, is left unfolded: returning an operand or a constant of a
+/// type other than the result type would produce invalid IR.
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = getInput1();
+  auto rhs = getInput2();
+  auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy)
+    return {};
+  if (lhsTy != rhsTy || lhsTy != resultTy)
+    return {};
+
+  auto resultETy = resultTy.getElementType();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  const auto shift = static_cast<int32_t>(getShift());
+
+  // True if `attr` is a splat constant equal to zero.
+  // NOTE(review): for floats this ignores IEEE semantics (NaN * 0 is NaN and
+  // -x * 0 is -0.0); kept because the existing fold tests rely on it.
+  auto isSplatZero = [&](DenseElementsAttr attr) -> bool {
+    if (!attr || !attr.isSplat())
+      return false;
+    if (resultETy.isa<FloatType>())
+      return attr.getSplatValue<APFloat>().isZero();
+    if (resultETy.isa<IntegerType>())
+      return attr.getSplatValue<APInt>().isZero();
+    return false;
+  };
+
+  // True if `attr` is a splat constant that leaves the other operand
+  // unchanged: 1.0 for floats, and 1 << shift for integers (the integer
+  // product is shifted right by `shift` afterwards).
+  auto isSplatOne = [&](DenseElementsAttr attr) -> bool {
+    if (!attr || !attr.isSplat())
+      return false;
+    if (resultETy.isa<FloatType>())
+      return attr.getSplatValue<APFloat>().isExactlyValue(1.0);
+    if (resultETy.isa<IntegerType>()) {
+      APInt val = attr.getSplatValue<APInt>();
+      unsigned bitwidth = val.getBitWidth();
+      // 1 << shift must be a positive value of the element type. Comparing
+      // APInts also avoids the undefined `1 << shift` for shift >= 31 and the
+      // 64-bit limit of getSExtValue().
+      if (shift < 0 || static_cast<unsigned>(shift) + 1 >= bitwidth)
+        return false;
+      return val ==
+             APInt::getOneBitSet(bitwidth, static_cast<unsigned>(shift));
+    }
+    return false;
+  };
+
+  if (isSplatZero(lhsAttr))
+    return lhsAttr;
+  if (isSplatOne(lhsAttr))
+    return rhs;
+  if (isSplatZero(rhsAttr))
+    return rhsAttr;
+  if (isSplatOne(rhsAttr))
+    return lhs;
+
+  return MulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
+}
+
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index f023cc3571ee..5187143f9ee0 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,115 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
// -----
+
+// CHECK-LABEL: @fold_mul_zero_rhs_f32
+// Multiplying by a splat 0.0 on the right folds to the zero constant itself.
+func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+ %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+ %mul = "tosa.mul"(%arg0, %zero) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: return %[[ZERO]]
+ return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_f32
+// Multiplying by a splat 0.0 on the left folds to the zero constant itself.
+func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+ %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+ %mul = "tosa.mul"(%zero, %arg0) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: return %[[ZERO]]
+ return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_rhs_i32
+// Integer variant: a splat 0 on the right folds to the zero constant itself.
+func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+ %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+ %mul = "tosa.mul"(%arg0, %zero) {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // CHECK: return %[[ZERO]]
+ return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_i32
+// Integer variant: a splat 0 on the left folds to the zero constant itself.
+func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+ %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+ %mul = "tosa.mul"(%zero, %arg0) {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // CHECK: return %[[ZERO]]
+ return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_f32
+// Multiplying by a splat 1.0 on the right folds to the other operand.
+func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+ %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+ %mul = "tosa.mul"(%arg0, %one) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: return %arg0
+ return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_f32
+// Multiplying by a splat 1.0 on the left folds to the other operand.
+func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+ %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+ %mul = "tosa.mul"(%one, %arg0) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: return %arg0
+ return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_i32
+// With shift = 6, a splat of 64 (1 << 6) on the right acts as the identity,
+// so the mul folds to the other operand.
+func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+ %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
+ %mul = "tosa.mul"(%arg0, %one) {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // CHECK: return %arg0
+ return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_i32
+// With shift = 6, a splat of 64 (1 << 6) on the left acts as the identity,
+// so the mul folds to the other operand.
+func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+ %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
+ %mul = "tosa.mul"(%one, %arg0) {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // CHECK: return %arg0
+ return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_i8
+// Two splat operands fold to a splat constant: 17 * 32 = 544, and 544 >> 3 = 68.
+func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+  %lhs = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
+  %rhs = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
+  %mul = "tosa.mul"(%lhs, %rhs) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
+  // CHECK: %[[RES:.+]] = "tosa.const"() {value = dense<68> : tensor<10xi8>}
+  // CHECK: return %[[RES]]
+  return %mul : tensor<10xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_f32
+// Two splat float operands fold to a splat constant: 3.0 * 2.0 = 6.0.
+func.func @fold_mul_splat_f32() -> tensor<10xf32> {
+  %lhs = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %rhs = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %mul = "tosa.mul"(%lhs, %rhs) {shift = 0 : i32} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[RES:.+]] = "tosa.const"() {value = dense<6.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[RES]]
+  return %mul : tensor<10xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_sub_zero_rhs_f32
func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
More information about the Mlir-commits
mailing list