[Mlir-commits] [mlir] c8dcdba - [mlir][tosa] Fix constant folding of tosa.mul
Eric Kunze
llvmlistbot at llvm.org
Thu May 18 16:45:33 PDT 2023
Author: Spenser Bauman
Date: 2023-05-18T16:42:06-07:00
New Revision: c8dcdba125193581c5af44bb52c33826ea8625e3
URL: https://github.com/llvm/llvm-project/commit/c8dcdba125193581c5af44bb52c33826ea8625e3
DIFF: https://github.com/llvm/llvm-project/commit/c8dcdba125193581c5af44bb52c33826ea8625e3.diff
LOG: [mlir][tosa] Fix constant folding of tosa.mul
The constant folder for tosa.mul produces a tensor attribute whose type
may not match the result type of the operation when broadcasting is
needed. This results in a tosa.const op whose attribute's type does not
match the type of the const op.
This change explicitly expands the attribute to the expected result
type.
Reviewed By: eric-k256, jpienaar
Differential Revision: https://reviews.llvm.org/D150439
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4827a9ee54c4..b9fa2f80601e7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1846,6 +1846,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
// Operator: const
//===----------------------------------------------------------------------===//
def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
+ AllShapesMatch<["value", "output"]>,
FirstAttrDerivedResultType]> {
let summary = "Constant op.";
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3eb83e015dac0..2477f154bc501 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -648,13 +648,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
- return lhsAttr;
+ return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
if (isSplatZero(resultETy, rhsAttr))
- return rhsAttr;
+ return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e9bd1225fe7ca..3379850faf272 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -203,6 +203,19 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
return %1 : tensor<2x3xi32>
}
+// CHECK-LABEL: @mul_zero_broadcast
+func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+ // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}> : () -> tensor<2x3xf32>
+ // CHECK-NOT: tosa.mul
+ %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+ %1 = "tosa.mul"(%arg0, %zeros) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+
+ // CHECK-NOT: tosa.mul
+ // CHECK: return %[[ZERO]], %[[ZERO]]
+ %2 = "tosa.mul"(%zeros, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
+}
+
// CHECK-LABEL: @select_same_value
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index edb4bb0a873ec..e285a9de1d66d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -143,3 +143,11 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
return
}
+
+// -----
+
+func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
+ // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
+ %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
+ return %0 : tensor<100x100xf32>
+}
More information about the Mlir-commits
mailing list