[Mlir-commits] [mlir] [Tosa] Fix TosaValidation for FuncOp (PR #69997)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 23 18:35:16 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This patch changes the TosaValidation pass so that it can run either on a FuncOp or on a ModuleOp.

TOSA variable checks are only enabled when the pass runs on a ModuleOp, because variable declarations may be located outside of functions.

It also adds a ModuleOp pass, --tosa-to-linalg-pipeline, which calls the function addTosaToLinalgPasses so that this pipeline gets tested,
together with a test, tosa-to-linalg-pipeline.mlir.

---
Full diff: https://github.com/llvm/llvm-project/pull/69997.diff


7 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+20-4) 
- (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+1) 
- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+1-1) 
- (modified) mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt (+1) 
- (added) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp (+65) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+5-4) 
- (added) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+31) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 274784fe4a7b29c..f1df226d7058955 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
     already present in the IR will be kept as is.
 
     An LLVM datalayout string can be attached as an attribute to the module on
-    which the pass anchors. Such an attribute is attached by calling the 
-    set-module-datalayout pass. If present, an llvm::DataLayout object is 
+    which the pass anchors. Such an attribute is attached by calling the
+    set-module-datalayout pass. If present, an llvm::DataLayout object is
     created from this attribute and used in the conversion to LLVM.
 
     #### Output IR
@@ -816,12 +816,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
 def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
   let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
   let description = [{
-    This pass generates PTX instructions using inline assembly for NVVM 
+    This pass generates PTX instructions using inline assembly for NVVM
     operations implements `BasicPtxBuilderInterface`.
   }];
   let dependentDialects = [
     "NVVM::NVVMDialect",
-  ];  
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1129,6 +1129,22 @@ def TosaToLinalgNamed
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToLinalgPipeline
+//===----------------------------------------------------------------------===//
+
+def TosaToLinalgPipeline
+    : Pass<"tosa-to-linalg-pipeline", "ModuleOp"> {
+  let summary = "Lower TOSA to LinAlg on tensors and named operations with validation";
+  let description = [{
+    Runs the addTosaToLinalgPasses pipeline on a module so it can be tested
+    from mlir-opt. TOSA validation in the pipeline uses the BaseInference
+    profile, strict operation spec alignment and level EightK.
+  }];
+
+  let constructor = "tosa::createTosaToLinalgPipeline()";
+}
+
 //===----------------------------------------------------------------------===//
 // TosaToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index c411010603ac61f..19906461892501f 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -27,6 +27,7 @@ namespace tosa {
 
 std::unique_ptr<Pass> createTosaToLinalg();
 std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgPipeline();
 
 /// Populates passes to convert from TOSA to Linalg on buffers. At the end of
 /// the pass, the function will only contain linalg ops or standard ops if the
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150fb..81932ba8b8dd38a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -89,7 +89,7 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
   let cppNamespace = "mlir::tosa";
 }
 
-def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
+def TosaValidation : Pass<"tosa-validate"> {
   let summary = "Validates TOSA dialect";
   let description = [{
     This pass validates if input TOSA operations match the specification for given
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index 4b79bf82810c58d..f35cbc9e8dcd1fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
   TosaToLinalgNamed.cpp
   TosaToLinalgNamedPass.cpp
   TosaToLinalgPass.cpp
+  TosaToLinalgPipeline.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
new file mode 100644
index 000000000000000..514011fc92accc0
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
@@ -0,0 +1,65 @@
+//===- TosaToLinalgPipeline.cpp - Lowering Tosa to Linalg Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOLINALGPIPELINE
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Module pass that runs addTosaToLinalgPasses so mlir-opt can test it.
+struct TosaToLinalgPipeline
+    : public impl::TosaToLinalgPipelineBase<TosaToLinalgPipeline> {
+  /// Populates `pm` with the pipeline under test (fixed validation options).
+  static void buildPipeline(OpPassManager &pm) {
+    tosa::addTosaToLinalgPasses(pm, /*disableTosaDecompositions=*/false,
+                                {tosa::TosaProfileEnum::BaseInference,
+                                 /*StrictOperationSpecAlignment=*/true,
+                                 tosa::TosaLevelEnum::EightK});
+  }
+  // Derive dialects from the nested passes so this list cannot go stale.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    OpPassManager pm("builtin.module");
+    buildPipeline(pm);
+    pm.getDependentDialects(registry);
+  }
+  void runOnOperation() override {
+    OpPassManager pm("builtin.module");
+    buildPipeline(pm);
+    if (failed(runPipeline(pm, getOperation())))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgPipeline() {
+  return std::make_unique<TosaToLinalgPipeline>();
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 424a31175d61707..88bf02205c689f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
 
 #include <string>
-#include <unordered_map>
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -506,7 +505,9 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
 
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
-  getOperation().walk([&](Operation *op) {
+  Operation *topOp = getOperation();
+  const bool isModule = isa<ModuleOp>(topOp);
+  topOp->walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
       if ((profile == TosaProfileEnum::BaseInference) &&
           isa<FloatType>(getElementTypeOrSelf(operand))) {
@@ -526,8 +527,8 @@ void TosaValidation::runOnOperation() {
     if (failed(applyLevelCheck(op)))
       signalPassFailure();
 
-    // do variable type checks
-    if (failed(applyVariableCheck(op)))
+    // do variable type checks iff topOp is a ModuleOp
+    if (isModule && failed(applyVariableCheck(op)))
       signalPassFailure();
   });
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
new file mode 100644
index 000000000000000..ff932af18926464
--- /dev/null
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
+
+
+// -----
+
+// check that --tosa-validate checks on stateful ops do not kick in
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+  // expected-error@+1 {{failed to legalize operation 'tosa.variable'}}
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+// check that --tosa-to-linalg kicks in
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.abs'}}
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
+  return %0 : tensor<*xi8>
+}
+
+// -----
+
+// check that --tosa-validate=strict-op-spec-alignment kicks in
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+    : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/69997


More information about the Mlir-commits mailing list