[Mlir-commits] [mlir] e1bb203 - [mlir][tosa] Added folders for tosa.mul

Rob Suderman llvmlistbot at llvm.org
Mon Aug 29 16:45:34 PDT 2022


Author: Rob Suderman
Date: 2022-08-29T16:43:50-07:00
New Revision: e1bb203755bc8f797ca4e3b2160732f9ef9b356a

URL: https://github.com/llvm/llvm-project/commit/e1bb203755bc8f797ca4e3b2160732f9ef9b356a
DIFF: https://github.com/llvm/llvm-project/commit/e1bb203755bc8f797ca4e3b2160732f9ef9b356a.diff

LOG: [mlir][tosa] Added folders for tosa.mul

Added folders for tosa.mul that handle bypassing multiplication by zero
and by one, and fold multiplication of two splat tensors.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132678

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ff8c47c3b1a4..29b40cdbcaaa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -737,6 +737,7 @@ def Tosa_MulOp : Tosa_Op<"mul", [
   );
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39c74974ab1f..6e9765571e85 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -506,6 +506,89 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
                                                             lhsTy);
 }
 
+namespace { // File-local helper used by MulOp::fold.
+DenseElementsAttr MulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
+                                  RankedTensorType ty, int32_t shift) { // Folds splat * splat into a splat of `ty`; returns a null attribute otherwise.
+  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { // Both operands must be non-null splats.
+    if (ty.getElementType().isa<IntegerType>()) {
+      APInt l = lhs.getSplatValue<APInt>();
+      APInt r = rhs.getSplatValue<APInt>();
+
+      if (shift == 0) { // No rescale: multiply directly at the operands' bitwidth.
+        return DenseElementsAttr::get(ty, l * r);
+      }
+
+      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
+      l = l.sext(bitwidth * 2); // Sign-extend to twice the element width before multiplying.
+      r = r.sext(bitwidth * 2);
+      auto result = l * r;
+      result.lshrInPlace(shift); // NOTE(review): logical shift with no rounding term added -- confirm against the TOSA mul definition.
+      result = result.trunc(bitwidth); // Narrow back to the element width.
+      return DenseElementsAttr::get(ty, result);
+    }
+
+    if (ty.getElementType().isa<FloatType>()) { // Float path does not use `shift`.
+      APFloat l = lhs.getSplatValue<APFloat>();
+      APFloat r = rhs.getSplatValue<APFloat>();
+      APFloat result = l * r;
+      return DenseElementsAttr::get(ty, result);
+    }
+  }
+
+  return {}; // Nothing folded.
+}
+} // namespace
+
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { // Folds mul-by-zero, mul-by-identity, and splat * splat; returns null when nothing applies.
+  auto lhs = getInput1();
+  auto rhs = getInput2();
+  auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy) // All three types must be ranked tensors.
+    return {};
+  if (lhsTy != rhsTy) // Operand types must match exactly.
+    return {};
+
+  auto resultETy = resultTy.getElementType();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>(); // Null unless the operand attribute is a DenseElementsAttr.
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) { // Float, constant splat on the left.
+    auto val = lhsAttr.getSplatValue<APFloat>();
+    if (val.isZero()) // NOTE(review): 0 * x -> 0 is not exact when x is NaN or Inf -- confirm this is intended.
+      return lhsAttr;
+    if (val.isExactlyValue(1.0)) // 1.0 * x -> x.
+      return rhs;
+  }
+
+  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) { // Float, constant splat on the right.
+    auto val = rhsAttr.getSplatValue<APFloat>();
+    if (val.isZero())
+      return rhsAttr;
+    if (val.isExactlyValue(1.0))
+      return lhs;
+  }
+
+  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) { // Integer, constant splat on the left.
+    auto val = lhsAttr.getSplatValue<APInt>();
+    if (val.isZero())
+      return lhsAttr;
+    if (val.getSExtValue() == (1 << getShift())) // Identity is 1 << shift. NOTE(review): the literal 1 is an int -- confirm large shifts cannot reach here.
+      return rhs;
+  }
+
+  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) { // Integer, constant splat on the right.
+    auto val = rhsAttr.getSplatValue<APInt>();
+    if (val.isZero())
+      return rhsAttr;
+    if (val.getSExtValue() == (1 << getShift()))
+      return lhs;
+  }
+
+  return MulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift()); // NOTE(review): built with lhsTy, not resultTy -- confirm they always agree here.
+}
+
 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
   auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index f023cc3571ee..5187143f9ee0 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,115 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+
+// CHECK-LABEL: @fold_mul_zero_rhs_f32
+func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+  %mul = "tosa.mul"(%arg0, %zero) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32> // Zero splat on the right: expected to fold to the zero constant.
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_f32
+func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+  %mul = "tosa.mul"(%zero, %arg0) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32> // Zero splat on the left: expected to fold to the zero constant.
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_rhs_i32
+func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+  %mul = "tosa.mul"(%arg0, %zero) {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32> // Integer zero splat on the right: expected to fold to the zero constant.
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_i32
+func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+  %mul = "tosa.mul"(%zero, %arg0) {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32> // Integer zero splat on the left: expected to fold to the zero constant.
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_f32
+func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %mul = "tosa.mul"(%arg0, %one) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32> // 1.0 splat on the right: expected to fold to %arg0.
+  // CHECK: return %arg0
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_f32
+func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %mul = "tosa.mul"(%one, %arg0) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32> // 1.0 splat on the left: expected to fold to %arg0.
+  // CHECK: return %arg0
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_i32
+func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32> // 64 == 1 << 6, the identity value when shift = 6.
+  %mul = "tosa.mul"(%arg0, %one) {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_i32
+func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32> // 64 == 1 << 6, the identity value when shift = 6.
+  %mul = "tosa.mul"(%one, %arg0) {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_i8
+func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+  %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
+  %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
+  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8> // (17 * 32) >> 3 == 68; the product 544 does not fit in i8 before the shift.
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<68> : tensor<10xi8>}
+  // CHECK: return %[[THREE]]
+  return %mul : tensor<10xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_f32
+func.func @fold_mul_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %mul = "tosa.mul"(%one, %two) {shift = 0 : i32} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> // 3.0 * 2.0 == 6.0, folded into a single splat constant.
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<6.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[THREE]]
+  return %mul : tensor<10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_sub_zero_rhs_f32
 func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>


        


More information about the Mlir-commits mailing list