[Mlir-commits] [mlir] [Tosa] Fix TosaValidation for FuncOp (PR #69997)

Tai Ly llvmlistbot at llvm.org
Mon Oct 23 18:37:19 PDT 2023


https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/69997

>From ea174e56c7ad69eaecb81774293859f389944cd6 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 19 Oct 2023 20:48:45 +0000
Subject: [PATCH] [Tosa] Fix TosaValidation for FuncOp

This patch changes the TosaValidation pass so that it works as
either a pass on FuncOp or a pass on ModuleOp.

Tosa Variable checks are only enabled on ModuleOp because
variable declarations may be outside of functions.

Also added a pass on ModuleOp, --tosa-to-linalg-pipeline,
and a test, tosa-to-linalg-pipeline.mlir. The new pass
calls the function addTosaToLinalgPasses so that the
function gets tested.

Reverted an earlier change in addTosaToLinalgPasses that
changed
addNestedPass<func::FuncOp>(tosa::createTosaValidation(...))
to
addNestedPass<mlir::ModuleOp>(tosa::createTosaValidation(...))

because it is no longer necessary.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: Ie0fb6a09c7dd8d4bd5304e283810a5f65f55e912
---
 mlir/include/mlir/Conversion/Passes.td        | 24 +++++--
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  1 +
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  2 +-
 .../Conversion/TosaToLinalg/CMakeLists.txt    |  1 +
 .../TosaToLinalg/TosaToLinalgPass.cpp         |  3 +-
 .../TosaToLinalg/TosaToLinalgPipeline.cpp     | 66 +++++++++++++++++++
 .../Tosa/Transforms/TosaValidation.cpp        |  9 +--
 .../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 50 ++++++++++++++
 8 files changed, 145 insertions(+), 11 deletions(-)
 create mode 100644 mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
 create mode 100644 mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 274784fe4a7b29c..f1df226d7058955 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
     already present in the IR will be kept as is.
 
     An LLVM datalayout string can be attached as an attribute to the module on
-    which the pass anchors. Such an attribute is attached by calling the 
-    set-module-datalayout pass. If present, an llvm::DataLayout object is 
+    which the pass anchors. Such an attribute is attached by calling the
+    set-module-datalayout pass. If present, an llvm::DataLayout object is
     created from this attribute and used in the conversion to LLVM.
 
     #### Output IR
@@ -816,12 +816,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
 def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
   let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
   let description = [{
-    This pass generates PTX instructions using inline assembly for NVVM 
+    This pass generates PTX instructions using inline assembly for NVVM
     operations implements `BasicPtxBuilderInterface`.
   }];
   let dependentDialects = [
     "NVVM::NVVMDialect",
-  ];  
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1129,6 +1129,22 @@ def TosaToLinalgNamed
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToLinalgPipeline
+//===----------------------------------------------------------------------===//
+
+def TosaToLinalgPipeline
+    : Pass<"tosa-to-linalg-pipeline", "ModuleOp"> {
+  let summary = "Lower TOSA to LinAlg on tensors and named operations with validation";
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations using the
+    tensor operations in LinAlg as well as LinAlg named operations.
+    This invokes addTosaToLinalgPasses pipeline to allow testing.
+  }];
+
+  let constructor = "tosa::createTosaToLinalgPipeline()";
+}
+
 //===----------------------------------------------------------------------===//
 // TosaToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index c411010603ac61f..19906461892501f 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -27,6 +27,7 @@ namespace tosa {
 
 std::unique_ptr<Pass> createTosaToLinalg();
 std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgPipeline();
 
 /// Populates passes to convert from TOSA to Linalg on buffers. At the end of
 /// the pass, the function will only contain linalg ops or standard ops if the
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150fb..81932ba8b8dd38a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -89,7 +89,7 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
   let cppNamespace = "mlir::tosa";
 }
 
-def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
+def TosaValidation : Pass<"tosa-validate"> {
   let summary = "Validates TOSA dialect";
   let description = [{
     This pass validates if input TOSA operations match the specification for given
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index 4b79bf82810c58d..f35cbc9e8dcd1fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
   TosaToLinalgNamed.cpp
   TosaToLinalgNamedPass.cpp
   TosaToLinalgPass.cpp
+  TosaToLinalgPipeline.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 2072fabc29242eb..3c54f85b033b0b6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses(
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
       {options.aggressiveReduceConstant}));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<mlir::ModuleOp>(
-      tosa::createTosaValidation(validationOptions));
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
new file mode 100644
index 000000000000000..4c2674f042da87c
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
@@ -0,0 +1,66 @@
+//===- TosaToLinalgPipeline.cpp - Lowering Tosa to Linalg Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOLINALGPIPELINE
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct TosaToLinalgPipeline
+    : public impl::TosaToLinalgPipelineBase<TosaToLinalgPipeline> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
+                tensor::TensorDialect, scf::SCFDialect>();
+  }
+
+  void runOnOperation() override {
+    OpPassManager pm("builtin.module");
+
+    TosaToLinalgOptions tosaToLinalgOptions;
+
+    tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+                                /* validationOptions = */
+                                {tosa::TosaProfileEnum::BaseInference,
+                                 /* StrictOperationSpecAlignment = */ true,
+                                 tosa::TosaLevelEnum::EightK});
+
+    if (failed(runPipeline(pm, getOperation())))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgPipeline() {
+  return std::make_unique<TosaToLinalgPipeline>();
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 424a31175d61707..88bf02205c689f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
 
 #include <string>
-#include <unordered_map>
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -506,7 +505,9 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
 
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
-  getOperation().walk([&](Operation *op) {
+  Operation *topOp = getOperation();
+  const bool isModule = isa<ModuleOp>(topOp);
+  topOp->walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
       if ((profile == TosaProfileEnum::BaseInference) &&
           isa<FloatType>(getElementTypeOrSelf(operand))) {
@@ -526,8 +527,8 @@ void TosaValidation::runOnOperation() {
     if (failed(applyLevelCheck(op)))
       signalPassFailure();
 
-    // do variable type checks
-    if (failed(applyVariableCheck(op)))
+    // do variable type checks iff topOp is a ModuleOp
+    if (isModule && failed(applyVariableCheck(op)))
       signalPassFailure();
   });
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
new file mode 100644
index 000000000000000..5477b12ff13463c
--- /dev/null
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
+
+
+// -----
+
+// check that the -tosa-validate checks of stateful ops do not kick in
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+  // expected-error@+1 {{failed to legalize operation 'tosa.variable'}}
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+// check that -tosa-validate level checking kicks in
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+  // expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}}
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
+  return %0 : tensor<*xi8>
+}
+
+// -----
+
+// check that the tosa verifier kicks in
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+    %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+      : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
+    return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+
+// check that --tosa-to-linalg kicks in
+func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
+  return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
+}
+
+// -----
+
+// check that --tosa-validate=strict-op-spec-alignment does not kick in because tosa-to-linalg-named comes before tosa-validate
+// this would have failed tosa strict-op-spec-alignment because perms of transpose is not constant
+// but tosa.transpose is lowered by tosa-to-linalg-named pass which is earlier than tosa-validate pass in the pipeline
+func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
+  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf32>
+}



More information about the Mlir-commits mailing list