[Mlir-commits] [mlir] a5f0b23 - [mlir][tosa][fix] Add proper type checking trait for tosa mul
Eric Kunze
llvmlistbot at llvm.org
Fri Jul 21 16:39:57 PDT 2023
Author: TatWai Chong
Date: 2023-07-21T23:29:05Z
New Revision: a5f0b237be7f2d50efdccd8d6a95edd05c8fd52f
URL: https://github.com/llvm/llvm-project/commit/a5f0b237be7f2d50efdccd8d6a95edd05c8fd52f
DIFF: https://github.com/llvm/llvm-project/commit/a5f0b237be7f2d50efdccd8d6a95edd05c8fd52f.diff
LOG: [mlir][tosa][fix] Add proper type checking trait for tosa mul
When operating on integer tensors, TOSA elementwise multiplication
requires the result element type to be a 32-bit integer rather
than the same type as the inputs.
Change-Id: Ifd3d7ebd879be5c6b2c8e23aa6d7ef41f39c6d41
Reviewed By: mgehre-amd
Differential Revision: https://reviews.llvm.org/D154988
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Tosa/constant-op-fold.mlir
mlir/test/Dialect/Tosa/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 4447247475f9ad..555d9bea18ba4d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -15,7 +15,9 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -35,6 +37,49 @@ namespace tosa {
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
} // namespace tosa
+
+namespace OpTrait {
+namespace tosa {
+
+// This trait verifies that the element types among the operands and result
+// of tosa.mul match the TOSA specification:
+//  - integer result: both operands must have the same integer element type,
+//    and the result must be at least as wide as the operands;
+//  - any other result element type (e.g. floating point): operands and
+//    result must all share the same element type.
+template <typename ConcreteType>
+class MulOperandsAndResultElementType
+    : public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    // The checks below index operands 0 and 1 and result 0 directly.
+    if (failed(impl::verifyAtLeastNOperands(op, 2)) ||
+        failed(impl::verifyOneResult(op)))
+      return failure();
+
+    Type resElemType = getElementTypeOrSelf(op->getResult(0));
+    auto resIntType = llvm::dyn_cast<IntegerType>(resElemType);
+
+    // Floating point (and any other non-integer element type) keeps the
+    // strict rule: the same element type for all operands and result. This
+    // helper emits its own diagnostic on failure.
+    if (!resIntType)
+      return impl::verifySameOperandsAndResultElementType(op);
+
+    // dyn_cast instead of cast: an integer result does not by itself
+    // guarantee integer operands, and cast would assert on e.g. f32 inputs.
+    auto lhsIntType =
+        llvm::dyn_cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
+    auto rhsIntType =
+        llvm::dyn_cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
+    if (!lhsIntType || !rhsIntType)
+      return op->emitOpError(
+          "requires integer operands when the result element type is integer");
+    if (lhsIntType != rhsIntType)
+      return op->emitOpError(
+          "requires the same element type for all operands");
+
+    // Though the spec requires the element type of result to be i32, a more
+    // relaxed way is provided at dialect level for easier cooperating with
+    // other dialects: only results narrower than the operands are rejected.
+    if (lhsIntType.getWidth() > resIntType.getWidth())
+      return op->emitOpError("invalid data type size for operands or result");
+
+    return success();
+  }
+};
+
+} // namespace tosa
+} // namespace OpTrait
+
} // namespace mlir
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 812db606128a25..3e3c070bac140e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -747,12 +747,17 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
);
}
+// Element-type constraint for tosa.mul. It is a native trait whose
+// verification lives in the C++ class
+// mlir::OpTrait::tosa::MulOperandsAndResultElementType.
+def MulOperandsAndResultElementType :
+ NativeOpTrait<"MulOperandsAndResultElementType"> {
+ let cppNamespace = "mlir::OpTrait::tosa";
+}
+
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
Commutative,
- SameOperandsAndResultElementType]> {
+ MulOperandsAndResultElementType]> {
let summary = "Multiplication operator";
let description = [{
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1055d6ff6fb784..29d57f2b7f696c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -538,8 +538,10 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
// CHECK-LABEL: @test_simple_i16
func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: linalg.generic
+  // With an i32 result, each i16 operand is expected to be sign-extended
+  // (arith.extsi) before the arith.muli.
+ // CHECK: arith.extsi
+ // CHECK: arith.extsi
// CHECK: arith.muli
- %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16>
+ %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index ec4d8bd74a5e83..e4762de5d0c8ed 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -294,13 +294,13 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
// -----
// CHECK-LABEL: @fold_mul_splat_i8
-func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+func.func @fold_mul_splat_i8() -> tensor<10xi32> {
%one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
%two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
- %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
- // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>}
+ %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
+  // Expected fold: (17 * 32) >> 3 = 544 >> 3 = 68, splatted with the i32
+  // result type.
+ // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
// CHECK: return %[[THREE]]
- return %mul : tensor<10xi8>
+ return %mul : tensor<10xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 0ad53bd4811ac7..f0ff06a0946981 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -229,6 +229,13 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te
return %0 : tensor<13x21x3xf32>
}
+// -----
+// An integer result as wide as its operands is accepted: the dialect relaxes
+// the spec's i32-result requirement and only rejects results narrower than
+// the operands.
+// CHECK-LABEL: test_mul_relaxed_result_type
+func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
+  %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
// -----
// CHECK-LABEL: pow
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
More information about the Mlir-commits
mailing list