[Mlir-commits] [mlir] 08b0977 - [mlir][tosa] Add check if the operand of the operations is constant.

Robert Suderman llvmlistbot at llvm.org
Tue Mar 21 14:11:00 PDT 2023


Author: TatWai Chong
Date: 2023-03-21T20:54:47Z
New Revision: 08b0977a1925cf0a2cf6f87fcbf1d656e873f7c5

URL: https://github.com/llvm/llvm-project/commit/08b0977a1925cf0a2cf6f87fcbf1d656e873f7c5
DIFF: https://github.com/llvm/llvm-project/commit/08b0977a1925cf0a2cf6f87fcbf1d656e873f7c5.diff

LOG: [mlir][tosa] Add check if the operand of the operations is constant.

Some uses of TOSA rely on particular operations having constant operands,
e.g. the padding and pad_const operands of the pad op. Add a verification
pattern to the validation pass; it is enabled optionally.

Change-Id: I1628c0840a27ab06ef91150eee56ad4f5ac9543d

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D145412

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 7fd2f9ba54f1..1c3bfbebb1cc 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -84,8 +84,12 @@ def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
 
   let options = [
       Option<"profileName", "profile", "std::string",
-      /*default=*/"\"undefined\"",
-      "Validation if ops match for given profile">];
+             /*default=*/"\"undefined\"",
+             "Validate if operations match for the given profile">,
+      // Opt-in: when set, operands that certain ops require to be constant
+      // (see the checkers in TosaValidation.cpp) are verified as constant.
+      Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
+             /*default=*/"false",
+             "Verify if the properties of certain operations align with the spec requirements">,
+  ];
 }
 
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 896f81e75daa..4cb727b00ca0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -35,17 +35,68 @@ using namespace mlir::tosa;
 
 namespace {
 
+/// Verifies that the operands of a tosa.pad op that must be constant
+/// (`padding`, and `pad_const` when present) are produced by constant-like
+/// ops. Ops other than tosa.pad are accepted unconditionally.
+static LogicalResult checkConstantOperandPad(Operation *op) {
+  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
+    // Match any ConstantLike producer instead of binding to DenseElementsAttr:
+    // a constant that folds to another ElementsAttr kind (e.g. a
+    // dense_resource) is still constant and must not be rejected.
+    if (!matchPattern(padOp.getPadding(), m_Constant()))
+      return op->emitOpError("padding of pad is not constant");
+
+    // pad_const is optional; assume zero-padding when it is not present.
+    if (padOp.getPadConst() &&
+        !matchPattern(padOp.getPadConst(), m_Constant()))
+      return op->emitOpError("pad_const of pad is not constant");
+  }
+  return success();
+}
+
+/// Verifies that the `perms` operand of a tosa.transpose op is produced by a
+/// constant-like op. Ops other than tosa.transpose are accepted
+/// unconditionally.
+static LogicalResult checkConstantOperandTranspose(Operation *op) {
+  if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
+    // Any ConstantLike producer qualifies, whatever ElementsAttr kind it
+    // folds to.
+    if (!matchPattern(transposeOp.getPerms(), m_Constant()))
+      return op->emitOpError("perms of transpose is not constant");
+  }
+  return success();
+}
+
+/// Verifies that the `weight` and `bias` operands of a tosa.fully_connected op
+/// are produced by constant-like ops. The weight is checked first, so only its
+/// diagnostic is emitted when both are non-constant. Ops other than
+/// tosa.fully_connected are accepted unconditionally.
+static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
+  if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
+    // Any ConstantLike producer qualifies, whatever ElementsAttr kind it
+    // folds to.
+    if (!matchPattern(fcOp.getWeight(), m_Constant()))
+      return op->emitOpError("weight of fully_connected is not constant");
+
+    if (!matchPattern(fcOp.getBias(), m_Constant()))
+      return op->emitOpError("bias of fully_connected is not constant");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Validation Pass.
 //===----------------------------------------------------------------------===//
 
+/// TOSA validation pass. Besides the profile check, it can run per-op
+/// constant-operand checkers when StrictOperationSpecAlignment is set.
 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 public:
-  explicit TosaValidation() = default;
+  // The constant-operand checkers are registered at construction.
+  // NOTE(review): only this default constructor is visible here; confirm any
+  // other constructor also calls populateConstantOperandChecks().
+  explicit TosaValidation() { populateConstantOperandChecks(); }
+  void runOnOperation() override;
+
+  /// Runs every registered checker on `op` in registration order and returns
+  /// failure at the first checker that fails (that checker has already
+  /// emitted the diagnostic). Returns success if all checkers pass.
+  LogicalResult applyConstantOperandCheck(Operation *op) {
+    for (auto &checker : const_checkers) {
+      if (failed(checker(op)))
+        return failure();
+    }
+    return success();
+  }
 
 private:
-  void runOnOperation() override;
+  /// Registers the per-op constant-operand checkers. Each checker returns
+  /// success for ops it does not handle, so all can run on every op.
+  void populateConstantOperandChecks() {
+    const_checkers.emplace_back(checkConstantOperandPad);
+    const_checkers.emplace_back(checkConstantOperandTranspose);
+    const_checkers.emplace_back(checkConstantOperandFullyConnected);
+  }
 
+  // Checkers consulted by applyConstantOperandCheck, in registration order.
+  SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
   std::optional<TosaProfileEnum> profileType;
 };
 
@@ -62,6 +113,10 @@ void TosaValidation::runOnOperation() {
         return signalPassFailure();
       }
     }
+
+    // Some uses of TOSA rely on the constant operands of particular operations.
+    if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
+      signalPassFailure();
   });
 }
 } // namespace

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9f9c6ca6ce64..bb7a3f5287c7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment
 
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
@@ -43,3 +43,48 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
   %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.pad' op padding of pad is not constant}}
+  %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_pad_const_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
+  %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+  // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
+  %1 = "tosa.pad"(%arg0, %0, %arg1) : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
+  // expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}}
+  %0 = "tosa.transpose"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf32>
+}
+
+// -----
+
+func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
+  %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
+  // expected-error@+1 {{'tosa.fully_connected' op weight of fully_connected is not constant}}
+  %2 = "tosa.fully_connected"(%1, %arg1, %0) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
+  return %2 : tensor<273x2xf32>
+}
+
+// -----
+
+func.func @test_fully_connected_bias_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2xf32>) -> tensor<273x2xf32> {
+  %0 = "tosa.const"() {value = dense<[[-0.613216758, -0.63714242, -0.73500061], [0.180762768, 0.773053169, -0.933686495]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
+  // expected-error@+1 {{'tosa.fully_connected' op bias of fully_connected is not constant}}
+  %2 = "tosa.fully_connected"(%1, %0, %arg1) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
+  return %2 : tensor<273x2xf32>
+}


        


More information about the Mlir-commits mailing list