[llvm] [mlir] [TOSA] Add TosaToMLProgram conversion (PR #69787)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 8 16:31:47 PST 2023


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/69787

>From 0edbdad3630dda39f366d425fbe072f5877d2ffd Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Fri, 20 Oct 2023 14:46:11 -0700
Subject: [PATCH] [TOSA] Add a pass to convert TOSA Variable Ops to MLProgram
 Global Ops

The TOSA variable ops and the ml_program global ops offer similar
functionality. The tosa-to-mlprogram pass legalizes the TOSA variable ops
to their MLProgram equivalents:

tosa.variable maps to ml_program.global.
tosa.variable.read maps to ml_program.global_load.
tosa.variable.write maps to ml_program.global_store.

Signed-off-by: Jerry Ge <jerry.ge at arm.com>
---
 mlir/include/mlir/Conversion/Passes.h         |  1 +
 mlir/include/mlir/Conversion/Passes.td        | 21 ++++-
 .../TosaToMLProgram/TosaToMLProgram.h         | 30 ++++++++
 mlir/lib/Conversion/CMakeLists.txt            |  1 +
 .../Conversion/TosaToMLProgram/CMakeLists.txt | 19 +++++
 .../TosaToMLProgram/TosaToMLProgram.cpp       | 76 +++++++++++++++++++
 .../TosaToMLProgram/TosaToMLProgramPass.cpp   | 48 ++++++++++++
 .../TosaToMLProgram/tosa-to-mlprogram.mlir    | 13 ++++
 .../llvm-project-overlay/mlir/BUILD.bazel     | 25 ++++++
 9 files changed, 230 insertions(+), 4 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
 create mode 100644 mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
 create mode 100644 mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
 create mode 100644 mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index e714f5070f23db8..637b69fc3f157b9 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -59,6 +59,7 @@
 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
 #include "mlir/Conversion/TosaToArith/TosaToArith.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a269fb4a83af41f..082aabc93428875 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
     already present in the IR will be kept as is.
 
     An LLVM datalayout string can be attached as an attribute to the module on
-    which the pass anchors. Such an attribute is attached by calling the 
-    set-module-datalayout pass. If present, an llvm::DataLayout object is 
+    which the pass anchors. Such an attribute is attached by calling the
+    set-module-datalayout pass. If present, an llvm::DataLayout object is
     created from this attribute and used in the conversion to LLVM.
 
     #### Output IR
@@ -794,12 +794,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
 def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
   let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
   let description = [{
-    This pass generates PTX instructions using inline assembly for NVVM 
+    This pass generates PTX instructions using inline assembly for NVVM
     operations implements `BasicPtxBuilderInterface`.
   }];
   let dependentDialects = [
     "NVVM::NVVMDialect",
-  ];  
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1107,6 +1107,19 @@ def TosaToLinalgNamed
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToMLProgram
+//===----------------------------------------------------------------------===//
+
+def TosaToMLProgram : Pass<"tosa-to-mlprogram", "ModuleOp"> {
+  let summary = "Lower TOSA to the MLProgram dialect";
+  let dependentDialects = ["ml_program::MLProgramDialect"];
+  let description = [{
+    Converts the TOSA variable ops (`tosa.variable`, `tosa.variable.read` and
+    `tosa.variable.write`) into the equivalent MLProgram global ops.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // TosaToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h b/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
new file mode 100644
index 000000000000000..f11c543cc78824d
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
@@ -0,0 +1,30 @@
+//===-- TosaToMLProgram.h - TOSA to MLProgram dialect lowerings-*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to MLProgram Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
+#define MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_TOSATOMLPROGRAM
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace tosa {
+/// Adds the patterns rewriting TOSA variable ops into MLProgram global ops.
+void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns);
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 35790254be137be..664804f0453509f 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -49,6 +49,7 @@ add_subdirectory(TensorToLinalg)
 add_subdirectory(TensorToSPIRV)
 add_subdirectory(TosaToArith)
 add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToMLProgram)
 add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
diff --git a/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt b/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
new file mode 100644
index 000000000000000..82941424f1d1025
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRTosaToMLProgram
+  TosaToMLProgram.cpp
+  TosaToMLProgramPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/TosaToMLProgram
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMLProgramDialect
+  MLIRPass
+  MLIRSupport
+  MLIRTosaDialect
+  MLIRTosaTransforms
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
new file mode 100644
index 000000000000000..d134d8cdf485e43
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -0,0 +1,76 @@
+//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the TOSA dialect to the MLProgram dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace tosa;
+namespace {
+
+// Rewrites `tosa.variable` into a mutable `ml_program.global` that carries the
+// same symbol name, type and initial value attribute; the new global is then
+// marked private.
+struct VariableOpConverter : OpRewritePattern<tosa::VariableOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::VariableOp op,
+                                PatternRewriter &rewriter) const final {
+    auto global = rewriter.replaceOpWithNewOp<ml_program::GlobalOp>(
+        op, op.getName(), op.getType(), /*is_mutable=*/true,
+        op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
+    global.setPrivate();
+    return success();
+  }
+};
+
+// Rewrites `tosa.variable.write` into an `ml_program.global_store` of the
+// written value into the global named by the op's symbol.
+struct VariableWriteOpConverter : OpRewritePattern<tosa::VariableWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
+                                PatternRewriter &rewriter) const final {
+    SymbolRefAttr target =
+        SymbolRefAttr::get(rewriter.getContext(), op.getName());
+    rewriter.replaceOpWithNewOp<ml_program::GlobalStoreOp>(op, target,
+                                                           op.getValue());
+    return success();
+  }
+};
+
+// Rewrites `tosa.variable.read` into an `ml_program.global_load` that loads a
+// value of the op's result type from the global named by the op's symbol.
+struct VariableReadOpConverter : OpRewritePattern<tosa::VariableReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::VariableReadOp op,
+                                PatternRewriter &rewriter) const final {
+    SymbolRefAttr source =
+        SymbolRefAttr::get(rewriter.getContext(), op.getName());
+    rewriter.replaceOpWithNewOp<ml_program::GlobalLoadOp>(op, op.getType(),
+                                                          source);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToMLProgramConversionPatterns(
+    RewritePatternSet *patterns) { // TOSA variable ops -> MLProgram globals.
+  patterns->add<VariableOpConverter, VariableReadOpConverter,
+                VariableWriteOpConverter>(patterns->getContext());
+}
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
new file mode 100644
index 000000000000000..8c39f5e8a63631d
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
@@ -0,0 +1,48 @@
+//===- TosaToMLProgramPass.cpp - Lowering Tosa to MLProgram Dialect--------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes the TOSA dialect to the MLProgram dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOMLPROGRAM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+/// Module pass that legalizes TOSA variable ops to their MLProgram forms.
+struct TosaToMLProgram : impl::TosaToMLProgramBase<TosaToMLProgram> {
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+    // Only the TOSA variable ops must be rewritten; everything else is legal.
+    ConversionTarget target(ctx);
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp,
+                        tosa::VariableWriteOp>();
+
+    RewritePatternSet patterns(&ctx);
+    tosa::populateTosaToMLProgramConversionPatterns(&patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
new file mode 100644
index 000000000000000..69b6875987daf17
--- /dev/null
+++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
+
+// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
+tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
+// CHECK-LABEL: func.func @test_stateful_ops
+func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
+  // CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
+  tosa.variable.write @var_x, %arg0 : tensor<1xf32>
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
+  %0 = tosa.variable.read @var_x : tensor<1xf32>
+  // CHECK: return %[[LOAD]] : tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
\ No newline at end of file
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index eb670ad50163c38..3b2b9f2660164c7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3730,6 +3730,7 @@ cc_library(
         ":TensorToSPIRV",
         ":TosaToArith",
         ":TosaToLinalg",
+        ":TosaToMLProgram",
         ":TosaToSCF",
         ":TosaToTensor",
         ":UBToLLVM",
@@ -11054,6 +11055,30 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TosaToMLProgram",
+    srcs = glob([
+        "lib/Conversion/TosaToMLProgram/*.cpp",
+        "lib/Conversion/TosaToMLProgram/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/TosaToMLProgram/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/TosaToMLProgram",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":IR",
+        ":MLProgramDialect",
+        ":Pass",
+        ":TosaDialect",
+        ":TransformUtils",
+        ":Transforms",
+    ],
+)
+
 cc_library(
     name = "TosaToSCF",
     srcs = glob([



More information about the llvm-commits mailing list