[Mlir-commits] [mlir] 3fbc6fd - [TOSA] Loosen folding restrictions for tosa.add, tosa.sub, tosa.mul

Robert Suderman llvmlistbot at llvm.org
Thu Mar 30 11:22:39 PDT 2023


Author: SJW
Date: 2023-03-30T18:22:20Z
New Revision: 3fbc6fd4931f91003a5441866b674d3d635d8a60

URL: https://github.com/llvm/llvm-project/commit/3fbc6fd4931f91003a5441866b674d3d635d8a60
DIFF: https://github.com/llvm/llvm-project/commit/3fbc6fd4931f91003a5441866b674d3d635d8a60.diff

LOG: [TOSA] Loosen folding restrictions for tosa.add, tosa.sub, tosa.mul

Allow folding when the operand tensor types differ because the constant operand is broadcast.
Remove the redundant and incorrect AddZero and MulOne canonicalization patterns.

Reviewed By: rsuderman, eric-k256

Differential Revision: https://reviews.llvm.org/D145738

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 33741c00d8495..043098f65a9ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -477,7 +477,6 @@ def Tosa_AddOp : Tosa_Op<"add", [
     Tosa_Tensor:$output
   );
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
@@ -796,7 +795,6 @@ def Tosa_MulOp : Tosa_Op<"mul", [
     Tosa_Tensor:$output
   );
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ef93e1955b60b..19a80c783c475 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -246,92 +246,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::AddOp op,
-                                PatternRewriter &rewriter) const override {
-    auto input1 = op.getInput1();
-    auto input2 = op.getInput2();
-
-    DenseElementsAttr input1Attr;
-    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
-        input2.getType() == op.getType()) {
-      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
-          input1Attr.getSplatValue<APInt>().isZero()) {
-        rewriter.replaceOp(op, op.getInput2());
-        return success();
-      }
-    }
-
-    DenseElementsAttr input2Attr;
-    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
-        input1.getType() == op.getType()) {
-      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
-          input2Attr.getSplatValue<APInt>().isZero()) {
-        rewriter.replaceOp(op, op.getInput1());
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<AddZeroOptimization>(context);
-}
-
-struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MulOp op,
-                                PatternRewriter &rewriter) const override {
-    auto input1 = op.getInput1();
-    auto input2 = op.getInput2();
-
-    DenseElementsAttr input1Attr;
-    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
-        input2.getType() == op.getType()) {
-      if (input1Attr.getType().getElementType().isa<FloatType>() &&
-          input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-        rewriter.replaceOp(op, op.getInput2());
-        return success();
-      }
-
-      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
-          matchPattern(input1, m_One())) {
-        rewriter.replaceOp(op, op.getInput2());
-        return success();
-      }
-    }
-
-    DenseElementsAttr input2Attr;
-    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
-        input1.getType() == op.getType()) {
-      if (input2Attr.getType().getElementType().isa<FloatType>() &&
-          input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-        rewriter.replaceOp(op, op.getInput1());
-        return success();
-      }
-
-      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
-          matchPattern(input2, m_One())) {
-        rewriter.replaceOp(op, op.getInput1());
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<MulOneOptimization>(context);
-}
-
 struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -609,44 +523,47 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
   return {};
 }
 
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+  if (elemType.isa<FloatType>())
+    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+  if (elemType.isa<IntegerType>())
+    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+  return false;
+}
+
+static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
+  if (elemType.isa<FloatType>())
+    return val && val.isSplat() &&
+           val.getSplatValue<APFloat>().isExactlyValue(1.0);
+  if (elemType.isa<IntegerType>()) {
+    const int64_t shifted = 1LL << shift;
+    return val && val.isSplat() &&
+           val.getSplatValue<APInt>().getSExtValue() == shifted;
+  }
+  return false;
+}
+
 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
   auto resultTy = getType().dyn_cast<RankedTensorType>();
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
-  if (lhsTy != rhsTy)
-    return {};
 
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
   auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
-  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
-    if (lhsAttr.getSplatValue<APFloat>().isZero())
-      return getInput2();
-  }
-
-  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
-    if (rhsAttr.getSplatValue<APFloat>().isZero())
-      return getInput1();
-  }
-
-  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
-    if (lhsAttr.getSplatValue<APInt>().isZero())
-      return getInput2();
-  }
-
-  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
-    if (rhsAttr.getSplatValue<APInt>().isZero())
-      return getInput1();
-  }
+  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
+    return getInput1();
+  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
+    return getInput2();
 
   if (!lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
-                                                            lhsTy);
+                                                            resultTy);
 }
 
 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
@@ -724,50 +641,26 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   auto resultTy = getType().dyn_cast<RankedTensorType>();
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
-  if (lhsTy != rhsTy)
-    return {};
 
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
   auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
-  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
-    auto val = lhsAttr.getSplatValue<APFloat>();
-    if (val.isZero())
+  const int64_t shift = resultETy.isa<IntegerType>() ? getShift() : 0;
+  if (rhsTy == resultTy) {
+    if (isSplatZero(resultETy, lhsAttr))
       return lhsAttr;
-    if (val.isExactlyValue(1.0))
+    if (isSplatOne(resultETy, lhsAttr, shift))
       return rhs;
   }
-
-  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
-    auto val = rhsAttr.getSplatValue<APFloat>();
-    if (val.isZero())
-      return rhsAttr;
-    if (val.isExactlyValue(1.0))
-      return lhs;
-  }
-
-  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
-    auto val = lhsAttr.getSplatValue<APInt>();
-    if (val.isZero())
-      return lhsAttr;
-    const int64_t shift = getShift();
-    const int64_t shifted = 1LL << shift;
-    if (val.getSExtValue() == shifted)
-      return rhs;
-  }
-
-  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
-    auto val = rhsAttr.getSplatValue<APInt>();
-    const int64_t shift = getShift();
-    const int64_t shifted = 1LL << shift;
-    if (val.isZero())
+  if (lhsTy == resultTy) {
+    if (isSplatZero(resultETy, rhsAttr))
       return rhsAttr;
-    if (val.getSExtValue() == shifted)
+    if (isSplatOne(resultETy, rhsAttr, shift))
       return lhs;
   }
 
-  return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
+  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
 }
 
 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
@@ -776,28 +669,19 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   auto resultTy = getType().dyn_cast<RankedTensorType>();
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
-  if (lhsTy != rhsTy)
-    return {};
 
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
   auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
-  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
-    if (rhsAttr.getSplatValue<APFloat>().isZero())
-      return getInput1();
-  }
-
-  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
-    if (rhsAttr.getSplatValue<APInt>().isZero())
-      return getInput1();
-  }
+  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
+    return getInput1();
 
   if (!lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
-                                                              lhsTy);
+                                                              resultTy);
 }
 
 namespace {

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 77627d8c8ba62..bdd4021cb39a1 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -7,15 +7,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   return %0 : tensor<?x1xf32>
 }
 
-// CHECK-LABEL: @add_zero_different_shape
-func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> {
-  // CHECK: tosa.add
-  %zeros = "tosa.const"() {value = dense<0> : tensor<4x2x3xi32>} : () -> tensor<4x2x3xi32>
-  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<4x2x3xi32>) -> tensor<4x2x3xi32>
+// CHECK-LABEL: @add_bcast_zero_int
+func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
+  // CHECK-NOT: tosa.add
+  // CHECK: return %arg0
+  %zeros = "tosa.const"() {value = dense<0> : tensor<1x1x1xi32>} : () -> tensor<1x1x1xi32>
+  %1 = "tosa.add"(%arg0, %zeros) : (tensor<4x2x3xi32>, tensor<1x1x1xi32>) -> tensor<4x2x3xi32>
   return %1 : tensor<4x2x3xi32>
 }
 
-
 // CHECK-LABEL: @add_zero_int
 func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
@@ -176,14 +176,6 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
   return %1 : tensor<?x?xi32>
 }
 
-// CHECK-LABEL: @mul_one_different_shape
-func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
-  // CHECK: tosa.mul
-  %ones = "tosa.const"() {value = dense<1.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
-  %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
-  return %1 : tensor<4x2x3xf32>
-}
-
 // CHECK-LABEL: @mul_one_float
 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
@@ -193,6 +185,15 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   return %1 : tensor<2x3xf32>
 }
 
+// CHECK-LABEL: @mul_bcast_one_float
+func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.mul
+  %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %1 = "tosa.mul"(%ones, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+}
+
 // CHECK-LABEL: @mul_one_int
 func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list