[Mlir-commits] [mlir] c8dcdba - [mlir][tosa] Fix constant folding of tosa.mul

Eric Kunze llvmlistbot at llvm.org
Thu May 18 16:45:33 PDT 2023


Author: Spenser Bauman
Date: 2023-05-18T16:42:06-07:00
New Revision: c8dcdba125193581c5af44bb52c33826ea8625e3

URL: https://github.com/llvm/llvm-project/commit/c8dcdba125193581c5af44bb52c33826ea8625e3
DIFF: https://github.com/llvm/llvm-project/commit/c8dcdba125193581c5af44bb52c33826ea8625e3.diff

LOG: [mlir][tosa] Fix constant folding of tosa.mul

The constant folder for tosa.mul produces a tensor attribute whose type
may not match the result type of the operation when broadcasting is
needed. This results in a tosa.const op whose attribute's type does not
match the type of the const op.
This change explicitly expands the splat attribute to the expected result
type, and adds a verifier constraint to tosa.const requiring the shape of
its value attribute to match the shape of its result.

Reviewed By: eric-k256, jpienaar

Differential Revision: https://reviews.llvm.org/D150439

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4827a9ee54c4..b9fa2f80601e7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1846,6 +1846,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
 // Operator: const
 //===----------------------------------------------------------------------===//
 def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
+                                     AllShapesMatch<["value", "output"]>,
                                      FirstAttrDerivedResultType]> {
   let summary = "Constant op.";
 

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3eb83e015dac0..2477f154bc501 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -648,13 +648,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
   if (rhsTy == resultTy) {
     if (isSplatZero(resultETy, lhsAttr))
-      return lhsAttr;
+      return lhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, lhsAttr, shift))
       return rhs;
   }
   if (lhsTy == resultTy) {
     if (isSplatZero(resultETy, rhsAttr))
-      return rhsAttr;
+      return rhsAttr.resizeSplat(resultTy);
     if (isSplatOne(resultETy, rhsAttr, shift))
       return lhs;
   }

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e9bd1225fe7ca..3379850faf272 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -203,6 +203,19 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   return %1 : tensor<2x3xi32>
 }
 
+// CHECK-LABEL: @mul_zero_broadcast
+func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+  // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> 
+  // CHECK-NOT: tosa.mul
+  %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %1 = "tosa.mul"(%arg0, %zeros) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+
+  // CHECK-NOT: tosa.mul
+  // CHECK: return %[[ZERO]], %[[ZERO]]
+  %2 = "tosa.mul"(%zeros, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
+}
+
 // CHECK-LABEL: @select_same_value
 func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index edb4bb0a873ec..e285a9de1d66d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -143,3 +143,11 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
   return
 }
+
+// -----
+
+func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
+  // expected-error @+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
+  %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
+  return %0 : tensor<100x100xf32>
+}


        


More information about the Mlir-commits mailing list