[Mlir-commits] [mlir] [mlir][tosa] Introduce arith.constant -> tosa.const normalization pass (PR #168370)
Vitalii Shutov
llvmlistbot at llvm.org
Thu Nov 27 07:13:43 PST 2025
https://github.com/Lallapallooza updated https://github.com/llvm/llvm-project/pull/168370
>From 9496f277489dc79a5a7486e4c9f6653fe9c99243 Mon Sep 17 00:00:00 2001
From: Vitalii Shutov <vitalii.shutov at arm.com>
Date: Tue, 11 Nov 2025 17:10:32 +0000
Subject: [PATCH] [TOSA] Introduce arith.constant -> tosa.const normalization
pass
Add a standalone pass that rewrites tensor-valued `arith.constant` ops into
`tosa.const` to normalize the TOSA backend contract.
Co-authored-by: Shubham <shubham at arm.com>
Signed-off-by: Vitalii Shutov <vitalii.shutov at arm.com>
Change-Id: I4e71926107633007a71bd1fcc3311a5da6d38849
---
.../mlir/Dialect/Tosa/Transforms/Passes.td | 9 ++
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../Transforms/TosaArithConstantToConst.cpp | 111 ++++++++++++++++++
.../Tosa/tosa-arith-const-to-tosa-const.mlir | 100 ++++++++++++++++
4 files changed, 221 insertions(+)
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 14b00b04ccc18..34572c5c4d131 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
}];
}
+def TosaArithConstantToTosaConstPass
+    : Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> {
+  let summary = "Convert tensor arith.constant operations into tosa.const";
+  let description = [{
+    Rewrites tensor arith.constant ops with tosa.const-compatible element types
+    into tosa.const so later TOSA passes see one constant form; others stay.
+  }];
+}
+
def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
let summary = "Convert integer types to signless";
let description = [{
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 41b338d6e7189..46c299834e2df 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaAttachTarget.cpp
+ TosaArithConstantToConst.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
new file mode 100644
index 0000000000000..73e1e2bee3399
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
@@ -0,0 +1,111 @@
+//===- TosaArithConstantToConst.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that converts tensor-valued arith.constant ops
+// into tosa.const so that TOSA pipelines operate on a uniform constant form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::tosa {
+// Pull in the TableGen-generated definition of this pass's base class
+// (tosa::impl::TosaArithConstantToTosaConstPassBase).
+#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace mlir::tosa
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+// NOTE: TOSA pipelines already lower their constants through shared Arith
+// folding passes, so tensor literals often come back as `arith.constant` even
+// after the IR is otherwise TOSA-only. Keep this normalization with the rest of
+// the TOSA transforms so any client can re-establish a canonical `tosa.const`
+// representation without needing a full Arith->TOSA conversion library.
+
+/// Returns true when `elementType` belongs to one of the element kinds this
+/// pass hands to tosa.const: floats, signless or unsigned integers, quantized
+/// types and tosa.mxint8. Signed integers, `index`, etc. are rejected.
+/// NOTE(review): every width/format within those kinds is accepted; the
+/// tosa.const result constraint in ODS may be narrower -- confirm before
+/// relying on this for exotic widths (e.g. i7, ui64, f80).
+static bool isSupportedElementType(Type elementType) {
+  if (auto intType = dyn_cast<IntegerType>(elementType))
+    return intType.isSignless() || intType.isUnsigned();
+  return isa<FloatType, quant::QuantizedType, tosa::mxint8Type>(elementType);
+}
+
+/// Replaces `constOp` with an equivalent tosa.const created via `rewriter`.
+///
+/// Only constants whose result is a statically shaped ranked tensor with a
+/// supported element type, and whose value is an ElementsAttr of exactly that
+/// type, are rewritten. Returns failure and leaves the IR untouched otherwise.
+static LogicalResult convertArithConstant(RewriterBase &rewriter,
+                                          arith::ConstantOp constOp) {
+  // TOSA constant verification requires a ranked, statically shaped tensor.
+  auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType());
+  if (!resultType || !resultType.hasStaticShape())
+    return failure();
+  if (!isSupportedElementType(resultType.getElementType()))
+    return failure();
+
+  // tosa.const stores its payload as an ElementsAttr. Requiring the attribute
+  // type to equal the (static) result type also rules out dynamic shapes.
+  auto elementsAttr = dyn_cast<ElementsAttr>(constOp.getValueAttr());
+  if (!elementsAttr ||
+      dyn_cast<RankedTensorType>(elementsAttr.getType()) != resultType)
+    return failure();
+
+  // A plain walk does not position the rewriter, so insert at the old op.
+  rewriter.setInsertionPoint(constOp);
+  auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(),
+                                        resultType, elementsAttr);
+  rewriter.replaceOp(constOp, newConst.getResult());
+  return success();
+}
+
+namespace {
+
+/// Function pass that converts every eligible tensor arith.constant nested in
+/// the function into tosa.const. Ineligible constants and all other
+/// operations are left untouched.
+struct TosaArithConstantToTosaConstPass
+    : public tosa::impl::TosaArithConstantToTosaConstPassBase<
+          TosaArithConstantToTosaConstPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // arith is the dialect being matched; tosa is needed to create tosa.const.
+    registry.insert<arith::ArithDialect, tosa::TosaDialect>();
+  }
+
+  void runOnOperation() override {
+    // Visit each arith.constant exactly once with a plain walk rather than the
+    // greedy pattern driver: with its default config the greedy driver also
+    // folds every op it visits, deduplicates/hoists constants and erases
+    // trivially dead ops, all of which is outside this normalization's scope.
+    // Erasing the visited op is safe because walk() is post-order by default.
+    IRRewriter rewriter(&getContext());
+    getOperation().walk([&](arith::ConstantOp constOp) {
+      // Ineligible constants are simply left as arith.constant.
+      (void)convertArithConstant(rewriter, constOp);
+    });
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
new file mode 100644
index 0000000000000..fc2d77ef375ec
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @rewrite_f32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: return %[[CST]]
+func.func @rewrite_f32_tensor() -> tensor<2xf32> {
+  %c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
+  return %c : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_i32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: return %[[CST]]
+func.func @rewrite_i32_tensor() -> tensor<3xi32> {
+  %c = arith.constant dense<[1, 0, -1]> : tensor<3xi32>
+  return %c : tensor<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_i1_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1>
+func.func @rewrite_i1_tensor() -> tensor<2xi1> {
+  %c = arith.constant dense<[true, false]> : tensor<2xi1>
+  return %c : tensor<2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_rank0_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32>
+func.func @rewrite_rank0_tensor() -> tensor<f32> {
+  %c = arith.constant dense<1.234500e+00> : tensor<f32>
+  return %c : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @preserve_scalar_i32
+// CHECK-NEXT: arith.constant 42 : i32
+func.func @preserve_scalar_i32() -> i32 {
+  %c = arith.constant 42 : i32
+  return %c : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @preserve_index_tensor
+// CHECK-NEXT: arith.constant dense<[0, 1]> : tensor<2xindex>
+func.func @preserve_index_tensor() -> tensor<2xindex> {
+  %c = arith.constant dense<[0, 1]> : tensor<2xindex>
+  return %c : tensor<2xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_resource_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32>
+func.func @rewrite_resource_tensor() -> tensor<4xf32> {
+  %c = arith.constant dense_resource<"blob1"> : tensor<4xf32>
+  return %c : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_ui8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8>
+func.func @rewrite_ui8_tensor() -> tensor<2xui8> {
+  %c = arith.constant dense<[10, 20]> : tensor<2xui8>
+  return %c : tensor<2xui8>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>
+func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform<i8:f32, 0.5:0>> {
+  %c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+  return %c : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_fp8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN>
+func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> {
+  %c = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN>
+  return %c : tensor<2xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_mxint8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8>
+func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> {
+  %c = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>
+  return %c : tensor<2x!tosa.mxint8>
+}
More information about the Mlir-commits
mailing list