[Mlir-commits] [mlir] a5f0b23 - [mlir][tosa][fix] Add proper type checking trait for tosa mul

Eric Kunze llvmlistbot at llvm.org
Fri Jul 21 16:39:57 PDT 2023


Author: TatWai Chong
Date: 2023-07-21T23:29:05Z
New Revision: a5f0b237be7f2d50efdccd8d6a95edd05c8fd52f

URL: https://github.com/llvm/llvm-project/commit/a5f0b237be7f2d50efdccd8d6a95edd05c8fd52f
DIFF: https://github.com/llvm/llvm-project/commit/a5f0b237be7f2d50efdccd8d6a95edd05c8fd52f.diff

LOG: [mlir][tosa][fix] Add proper type checking trait for tosa mul

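When operating on integer tensors, TOSA element-wise multiplication
requires the result element type to be a 32-bit integer rather than
the same type as the inputs. At the dialect level this is relaxed: an
integer result may be any integer type at least as wide as the operands.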

Change-Id: Ifd3d7ebd879be5c6b2c8e23aa6d7ef41f39c6d41

Reviewed By: mgehre-amd

Differential Revision: https://reviews.llvm.org/D154988

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Tosa/constant-op-fold.mlir
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 4447247475f9ad..555d9bea18ba4d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -15,7 +15,9 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -35,6 +37,59 @@ namespace tosa {
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
 } // namespace tosa
+
+namespace OpTrait {
+namespace tosa {
+
+// Verifies that the element types of the operands and result of a
+// multiplication match the TOSA specification:
+//  - floating-point result: operands and result share one element type;
+//  - integer result: both operands share one integer element type that is
+//    no wider than the result element type.
+// Any other result element type is rejected with a diagnostic.
+template <typename ConcreteType>
+class MulOperandsAndResultElementType
+    : public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    Type resElemType = getElementTypeOrSelf(op->getResult(0));
+
+    // In cases of floating point type, op requires the same element
+    // type for all operands and result.
+    if (llvm::isa<FloatType>(resElemType))
+      return impl::verifySameOperandsAndResultElementType(op);
+
+    auto resIntType = llvm::dyn_cast<IntegerType>(resElemType);
+    if (!resIntType)
+      return op->emitOpError(
+          "requires a floating-point or integer result element type");
+
+    // Compare the operand element types as plain Types first; an unchecked
+    // cast<IntegerType> would assert on e.g. a float operand.
+    Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
+    Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));
+    if (lhsElemType != rhsElemType)
+      return op->emitOpError(
+          "requires the same element type for all operands");
+
+    auto lhsIntType = llvm::dyn_cast<IntegerType>(lhsElemType);
+    if (!lhsIntType)
+      return op->emitOpError(
+          "requires integer operands for an integer result element type");
+
+    // Though the spec requires the element type of result to be i32, a more
+    // relaxed way is provided at dialect level for easier cooperating with
+    // other dialects.
+    if (lhsIntType.getWidth() > resIntType.getWidth())
+      return op->emitOpError("invalid data type size for operands or result");
+
+    return success();
+  }
+};
+
+} // namespace tosa
+} // namespace OpTrait
+
 } // namespace mlir
 
 #define GET_ATTRDEF_CLASSES

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 812db606128a25..3e3c070bac140e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -747,12 +747,17 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
   );
 }
 
+def MulOperandsAndResultElementType : // element-type verifier used by tosa.mul
+  NativeOpTrait<"MulOperandsAndResultElementType"> {
+  let cppNamespace = "mlir::OpTrait::tosa"; // C++ template defined in TosaOps.h
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: mul
 //===----------------------------------------------------------------------===//
 def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
     Commutative,
-    SameOperandsAndResultElementType]> {
+    MulOperandsAndResultElementType]> {
   let summary = "Multiplication operator";
 
   let description = [{

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1055d6ff6fb784..29d57f2b7f696c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -538,8 +538,10 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
 // CHECK-LABEL: @test_simple_i16
 func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: linalg.generic
+  // CHECK: arith.extsi
+  // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16>
+  %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32> // i16 x i16 -> i32: both operands are sign-extended before the multiply
 
   return
 }

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index ec4d8bd74a5e83..e4762de5d0c8ed 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -294,13 +294,13 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // -----
 
 // CHECK-LABEL: @fold_mul_splat_i8
-func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+func.func @fold_mul_splat_i8() -> tensor<10xi32> {
   %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
   %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
-  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>}
+  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32> // folds to (17 * 32) >> 3 = 544 >> 3 = 68
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
   // CHECK: return %[[THREE]]
-  return %mul : tensor<10xi8>
+  return %mul : tensor<10xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 0ad53bd4811ac7..f0ff06a0946981 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -229,6 +229,13 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: mul
+func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
+  %0 = "tosa.mul"(%arg0, %arg1)  { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16> // non-i32 integer result accepted: the dialect only requires result width >= operand width
+  return %0 : tensor<13x21x3xi16>
+}
+
 // -----
 // CHECK-LABEL: pow
 func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {


        


More information about the Mlir-commits mailing list