[Mlir-commits] [mlir] 39f5ff0 - [mlir][tosa] Introduce arith.constant -> tosa.const normalization pass (#168370)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 27 07:55:34 PST 2025


Author: Vitalii Shutov
Date: 2025-11-27T15:55:30Z
New Revision: 39f5ff056bc459c7db4d01c348fe78925da8c558

URL: https://github.com/llvm/llvm-project/commit/39f5ff056bc459c7db4d01c348fe78925da8c558
DIFF: https://github.com/llvm/llvm-project/commit/39f5ff056bc459c7db4d01c348fe78925da8c558.diff

LOG: [mlir][tosa] Introduce arith.constant -> tosa.const normalization pass (#168370)

Add a standalone pass that rewrites tensor-valued `arith.constant` ops
into `tosa.const` to normalize the TOSA backend contract.

Signed-off-by: Vitalii Shutov <vitalii.shutov at arm.com>
Co-authored-by: Shubham <shubham at arm.com>

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
    mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 420e58192b8fd..12f520297b702 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
   }];
 }
 
+def TosaArithConstantToTosaConstPass
+    : Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> {
+  let summary = "Convert tensor arith.constant operations into tosa.const";
+  let description = [{
+    Normalizes tensor-valued arith.constant operations into tosa.const so that
+    subsequent TOSA passes operate on a consistent representation of constants.
+
+    Only constants whose result is a statically shaped ranked tensor with a
+    floating-point, signless or unsigned integer, quantized, or mxint8 element
+    type are converted. Other constants, such as scalars or tensors of index
+    or signed integer elements, are left unchanged.
+  }];
+}
+
 def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
   let summary = "Convert integer types to signless";
   let description = [{

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 4aa5b4523bbe6..091b481d6394b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRTosaTransforms
   TosaAttachTarget.cpp
+  TosaArithConstantToConst.cpp
   TosaConvertIntegerTypeToSignless.cpp
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeDepthwise.cpp

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
new file mode 100644
index 0000000000000..73e1e2bee3399
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
@@ -0,0 +1,111 @@
+//===- TosaArithConstantToConst.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that converts tensor-valued arith.constant ops
+// into tosa.const so that TOSA pipelines operate on a uniform constant form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+// Pull in the TableGen-generated definition of the pass base class
+// (tosa::impl::TosaArithConstantToTosaConstPassBase) used further below.
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// NOTE: TOSA pipelines often run their constants through shared Arith folding
+// passes, so tensor literals can reappear as `arith.constant` even after the
+// rest of the IR is TOSA-only. Keeping this normalization alongside the other
+// TOSA transforms lets any client re-establish the canonical `tosa.const` form
+// without pulling in a full Arith->TOSA conversion library.
+
+/// Reports whether tosa.const can hold values of `elementType` directly.
+/// Accepted: any floating-point type, signless or unsigned integers, quantized
+/// types, and the TOSA mxint8 type. Everything else (signed integers, index,
+/// ...) is rejected.
+static bool isSupportedElementType(Type elementType) {
+  // Signedness is the only integer property that matters: signed integers are
+  // the one integer flavour that is rejected.
+  if (auto intTy = dyn_cast<IntegerType>(elementType))
+    return !intTy.isSigned();
+
+  return isa<FloatType, quant::QuantizedType, tosa::mxint8Type>(elementType);
+}
+
+/// Rewrites an `arith.constant` holding a tensor literal into a `tosa.const`
+/// that carries the very same elements attribute. Constants tosa.const cannot
+/// represent (non-tensor or unranked results, dynamic shapes, unsupported
+/// element types, non-elements payloads) are left untouched.
+class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constOp,
+                                PatternRewriter &rewriter) const override {
+    // tosa.const only verifies on ranked tensors with a fully static shape and
+    // a supported element type.
+    auto tensorTy = dyn_cast<RankedTensorType>(constOp.getResult().getType());
+    if (!tensorTy || !tensorTy.hasStaticShape() ||
+        !isSupportedElementType(tensorTy.getElementType()))
+      return failure();
+
+    // The payload must be an elements attribute typed exactly like the result.
+    // Because `tensorTy` is static, type equality also guarantees the
+    // attribute's type is a statically shaped ranked tensor.
+    auto values = dyn_cast<ElementsAttr>(constOp.getValueAttr());
+    if (!values)
+      return failure();
+    auto valuesTy = dyn_cast<RankedTensorType>(values.getType());
+    if (!valuesTy || valuesTy != tensorTy)
+      return failure();
+
+    auto tosaConst =
+        tosa::ConstOp::create(rewriter, constOp.getLoc(), tensorTy, values);
+    rewriter.replaceOp(constOp, tosaConst.getResult());
+    return success();
+  }
+};
+
+/// Function-level pass that rewrites every eligible tensor `arith.constant`
+/// nested in the function into `tosa.const` via ArithConstantToTosaConst.
+struct TosaArithConstantToTosaConstPass
+    : public tosa::impl::TosaArithConstantToTosaConstPassBase<
+          TosaArithConstantToTosaConstPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The pass creates tosa.const ops, so the TOSA dialect must be loaded
+    // before it runs; Arith is registered alongside it.
+    registry.insert<arith::ArithDialect, tosa::TosaDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ArithConstantToTosaConst>(ctx);
+
+    // A single walk is enough: each arith.constant is visited once and the
+    // rewrite never produces new arith.constant ops, so no fixpoint iteration
+    // is required. Unlike the greedy driver, the walk driver does not fold
+    // other ops, erase dead ops, or deduplicate/hoist constants, so only the
+    // targeted constants are changed and the pass cannot fail to converge.
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace

diff  --git a/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
new file mode 100644
index 0000000000000..fc2d77ef375ec
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s
+
+// A float tensor constant becomes a tosa.const with identical values.
+// CHECK-LABEL: func.func @rewrite_f32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: return %[[CST]]
+func.func @rewrite_f32_tensor() -> tensor<2xf32> {
+  %c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
+  return %c : tensor<2xf32>
+}
+
+// -----
+
+// Signless integer tensors, including negative values, are rewritten.
+// CHECK-LABEL: func.func @rewrite_i32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: return %[[CST]]
+func.func @rewrite_i32_tensor() -> tensor<3xi32> {
+  %c = arith.constant dense<[1, 0, -1]> : tensor<3xi32>
+  return %c : tensor<3xi32>
+}
+
+// -----
+
+// Boolean (i1) tensors are signless integers and are rewritten as well.
+// CHECK-LABEL: func.func @rewrite_i1_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1>
+func.func @rewrite_i1_tensor() -> tensor<2xi1> {
+  %c = arith.constant dense<[true, false]> : tensor<2xi1>
+  return %c : tensor<2xi1>
+}
+
+// -----
+
+// A rank-0 tensor still has a static shape, so it is rewritten.
+// CHECK-LABEL: func.func @rewrite_rank0_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32>
+func.func @rewrite_rank0_tensor() -> tensor<f32> {
+  %c = arith.constant dense<1.234500e+00> : tensor<f32>
+  return %c : tensor<f32>
+}
+
+// -----
+
+// Scalar constants are not tensors and stay as arith.constant.
+// CHECK-LABEL: func.func @preserve_scalar_i32
+// CHECK: %[[CST:.*]] = arith.constant 42 : i32
+func.func @preserve_scalar_i32() -> i32 {
+  %c = arith.constant 42 : i32
+  return %c : i32
+}
+
+// -----
+
+// The pass does not accept the index element type, so the constant stays as
+// arith.constant.
+// CHECK-LABEL: func.func @preserve_index_tensor
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex>
+func.func @preserve_index_tensor() -> tensor<2xindex> {
+  %c = arith.constant dense<[0, 1]> : tensor<2xindex>
+  return %c : tensor<2xindex>
+}
+
+// -----
+
+// The pass rejects signed integer element types, so the constant stays as
+// arith.constant.
+// CHECK-LABEL: func.func @preserve_signed_int_tensor
+// CHECK: %[[CST:.*]] = arith.constant dense<[1, -1]> : tensor<2xsi8>
+func.func @preserve_signed_int_tensor() -> tensor<2xsi8> {
+  %c = arith.constant dense<[1, -1]> : tensor<2xsi8>
+  return %c : tensor<2xsi8>
+}
+
+// -----
+
+// dense_resource payloads are elements attributes too; the resource reference
+// is carried over into the tosa.const unchanged.
+// CHECK-LABEL: func.func @rewrite_resource_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32>
+func.func @rewrite_resource_tensor() -> tensor<4xf32> {
+  %c = arith.constant dense_resource<"blob1"> : tensor<4xf32>
+  return %c : tensor<4xf32>
+}
+
+// -----
+
+// Despite its name, this case uses a plain unsigned integer (ui8) element type
+// rather than a quant dialect type; unsigned integers are accepted.
+// CHECK-LABEL: func.func @rewrite_quant_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8>
+func.func @rewrite_quant_tensor() -> tensor<2xui8> {
+  %c = arith.constant dense<[10, 20]> : tensor<2xui8>
+  return %c : tensor<2xui8>
+}
+
+// -----
+
+// Uniform quantized element types are accepted; note the printed type omits
+// the zero point of 0 given in the input.
+// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>
+func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform<i8:f32, 0.5:0>> {
+  %c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+  return %c : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+}
+
+// -----
+
+// Narrow float types such as f8E4M3FN are accepted like any other float type.
+// CHECK-LABEL: func.func @rewrite_fp8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN>
+func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> {
+  %c = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN>
+  return %c : tensor<2xf8E4M3FN>
+}
+
+// -----
+
+// The TOSA mxint8 element type is accepted; its values are written as hex
+// strings.
+// CHECK-LABEL: func.func @rewrite_mxint8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8>
+func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> {
+  %c = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>
+  return %c : tensor<2x!tosa.mxint8>
+}


        


More information about the Mlir-commits mailing list