[Mlir-commits] [mlir] [Tosa] Fix TosaValidation for FuncOp (PR #69997)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 23 18:35:16 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This patch changes the TosaValidation pass so that it can run either on a FuncOp or on a ModuleOp.
TOSA variable checks are only enabled when the pass runs on a ModuleOp, because variable declarations may be outside of functions.
It also adds a ModuleOp pass, --tosa-to-linalg-pipeline, which calls addTosaToLinalgPasses so that this pipeline gets tested,
together with a new test, tosa-to-linalg-pipeline.mlir.
---
Full diff: https://github.com/llvm/llvm-project/pull/69997.diff
7 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+20-4)
- (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+1)
- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+1-1)
- (modified) mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt (+1)
- (added) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp (+65)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+5-4)
- (added) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+31)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 274784fe4a7b29c..f1df226d7058955 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
already present in the IR will be kept as is.
An LLVM datalayout string can be attached as an attribute to the module on
- which the pass anchors. Such an attribute is attached by calling the
- set-module-datalayout pass. If present, an llvm::DataLayout object is
+ which the pass anchors. Such an attribute is attached by calling the
+ set-module-datalayout pass. If present, an llvm::DataLayout object is
created from this attribute and used in the conversion to LLVM.
#### Output IR
@@ -816,12 +816,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
let description = [{
- This pass generates PTX instructions using inline assembly for NVVM
+ This pass generates PTX instructions using inline assembly for NVVM
operations implements `BasicPtxBuilderInterface`.
}];
let dependentDialects = [
"NVVM::NVVMDialect",
- ];
+ ];
}
//===----------------------------------------------------------------------===//
@@ -1129,6 +1129,22 @@ def TosaToLinalgNamed
let constructor = "tosa::createTosaToLinalgNamed()";
}
+//===----------------------------------------------------------------------===//
+// TosaToLinalgPipeline
+//===----------------------------------------------------------------------===//
+
+def TosaToLinalgPipeline
+    : Pass<"tosa-to-linalg-pipeline", "ModuleOp"> {
+  let summary = "Lower TOSA to LinAlg on tensors and named operations with validation";
+  let description = [{
+    Runs the `addTosaToLinalgPasses` pipeline, including TOSA validation, on a
+    module to lower TOSA operations to LinAlg tensor and named operations. It
+    exists so that this pipeline can be invoked from mlir-opt and tested.
+  }];
+
+  let constructor = "tosa::createTosaToLinalgPipeline()";
+}
+
//===----------------------------------------------------------------------===//
// TosaToSCF
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index c411010603ac61f..19906461892501f 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -27,6 +27,7 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToLinalg();
std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgPipeline();
/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
/// the pass, the function will only contain linalg ops or standard ops if the
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150fb..81932ba8b8dd38a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -89,7 +89,7 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
let cppNamespace = "mlir::tosa";
}
-def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
+def TosaValidation : Pass<"tosa-validate"> {
let summary = "Validates TOSA dialect";
let description = [{
This pass validates if input TOSA operations match the specification for given
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index 4b79bf82810c58d..f35cbc9e8dcd1fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
TosaToLinalgNamed.cpp
TosaToLinalgNamedPass.cpp
TosaToLinalgPass.cpp
+ TosaToLinalgPipeline.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
new file mode 100644
index 000000000000000..514011fc92accc0
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
@@ -0,0 +1,82 @@
+//===- TosaToLinalgPipeline.cpp - Lowering Tosa to Linalg Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the tosa-to-linalg-pipeline pass, which runs the
+// addTosaToLinalgPasses pipeline (including TOSA validation) on a module so
+// that the whole TOSA-to-Linalg lowering pipeline can be exercised in tests.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOLINALGPIPELINE
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Populates `pm` with the pipeline run by this pass: addTosaToLinalgPasses
+/// with TOSA decompositions enabled and validation configured for the
+/// BaseInference profile, strict op-spec alignment and the 8K level.
+/// Shared by getDependentDialects and runOnOperation so they cannot diverge.
+void buildTosaToLinalgPipeline(OpPassManager &pm) {
+  tosa::addTosaToLinalgPasses(pm,
+                              /* disableTosaDecompositions = */ false,
+                              /* validationOptions = */
+                              {tosa::TosaProfileEnum::BaseInference,
+                               /* StrictOperationSpecAlignment = */ true,
+                               tosa::TosaLevelEnum::EightK});
+}
+
+/// ModuleOp pass that runs the pipeline from buildTosaToLinalgPipeline as a
+/// dynamic pipeline on the current module.
+struct TosaToLinalgPipeline
+    : public impl::TosaToLinalgPipelineBase<TosaToLinalgPipeline> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
+                tensor::TensorDialect, scf::SCFDialect>();
+    // Dialects must be loaded before a multi-threaded pass pipeline starts,
+    // so also declare whatever the nested (dynamic) pipeline depends on
+    // instead of relying on the hand-written list above being complete.
+    OpPassManager pm("builtin.module");
+    buildTosaToLinalgPipeline(pm);
+    pm.getDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    OpPassManager pm("builtin.module");
+    buildTosaToLinalgPipeline(pm);
+
+    if (failed(runPipeline(pm, getOperation())))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgPipeline() {
+  return std::make_unique<TosaToLinalgPipeline>();
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 424a31175d61707..88bf02205c689f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
-#include <unordered_map>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -506,7 +505,9 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
void TosaValidation::runOnOperation() {
configLevelAndProfile();
- getOperation().walk([&](Operation *op) {
+ Operation *topOp = getOperation();
+ const bool isModule = isa<ModuleOp>(topOp);
+ topOp->walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
if ((profile == TosaProfileEnum::BaseInference) &&
isa<FloatType>(getElementTypeOrSelf(operand))) {
@@ -526,8 +527,8 @@ void TosaValidation::runOnOperation() {
if (failed(applyLevelCheck(op)))
signalPassFailure();
- // do variable type checks
- if (failed(applyVariableCheck(op)))
+ // do variable type checks iff topOp is a ModuleOp
+ if (isModule && failed(applyVariableCheck(op)))
signalPassFailure();
});
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
new file mode 100644
index 000000000000000..ff932af18926464
--- /dev/null
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
+
+
+// -----
+
+// check that -tosa-validate variable checks (e.g. the shape mismatch below) do not kick in
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+  // expected-error@+1 {{failed to legalize operation 'tosa.variable'}}
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  return
+}
+
+// -----
+
+// check that --tosa-to-linalg kicks in
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.abs'}}
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
+  return %0 : tensor<*xi8>
+}
+
+// -----
+
+// check that --tosa-validate=strict-op-spec-alignment kicks in
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+         : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/69997
More information about the Mlir-commits
mailing list