[Mlir-commits] [mlir] [mlir][tosa] Introduce arith.constant -> tosa.const normalization pass (PR #168370)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 17 06:07:11 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Vitalii Shutov (Lallapallooza)

<details>
<summary>Changes</summary>

Add a standalone pass that rewrites tensor-valued `arith.constant` ops into `tosa.const` to normalize the TOSA backend contract.

---
Full diff: https://github.com/llvm/llvm-project/pull/168370.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+9) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp (+126) 
- (added) mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir (+110) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 14b00b04ccc18..34572c5c4d131 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
   }];
 }
 
+// Function-level pass; the implementation lives in TosaArithConstantToConst.cpp.
+def TosaArithConstantToTosaConstPass : Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> {
+  let description = [{
+    Normalizes tensor-valued arith.constant operations into tosa.const so that
+    subsequent TOSA passes operate on a consistent representation of constants.
+  }];
+  let summary = "Convert tensor arith.constant operations into tosa.const";
+}
+
 def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
   let summary = "Convert integer types to signless";
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 41b338d6e7189..46c299834e2df 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRTosaTransforms
   TosaAttachTarget.cpp
+  TosaArithConstantToConst.cpp
   TosaConvertIntegerTypeToSignless.cpp
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
new file mode 100644
index 0000000000000..8ddde9c05724e
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
@@ -0,0 +1,126 @@
+//===- TosaArithConstantToConst.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that converts tensor-valued arith.constant ops
+// into tosa.const so that TOSA pipelines operate on a uniform constant form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// NOTE: TOSA pipelines already fold their constants through shared Arith
+// passes, so tensor literals often come back as `arith.constant` even after
+// the IR is otherwise TOSA-only. This normalization lives with the other TOSA
+// transforms so any client can re-establish a canonical `tosa.const`
+// representation without depending on a full Arith->TOSA conversion library.
+
+/// Returns true when `elementType` is natively representable by tosa.const.
+///
+/// Accepted element types:
+///   * any floating-point type (including the fp8 variants),
+///   * signless and unsigned integers (explicitly signed integers such as
+///     si8 are rejected),
+///   * quantized types from the quant dialect,
+///   * the TOSA-specific mxint8 type.
+static bool isSupportedElementType(Type elementType) {
+  // IntegerType has three signedness flavours (signless, signed, unsigned);
+  // only the explicitly signed one is unsupported.
+  if (auto integerType = dyn_cast<IntegerType>(elementType))
+    return !integerType.isSigned();
+
+  return isa<FloatType, quant::QuantizedType, tosa::mxint8Type>(elementType);
+}
+
+/// Rewrites a tensor-valued arith.constant into an equivalent tosa.const.
+///
+/// The rewrite fires only when tosa.const can accept the result: it must be a
+/// statically shaped ranked tensor whose element type passes
+/// isSupportedElementType, and the payload must be an ElementsAttr. Constants
+/// that fail any of these checks are left as arith.constant. The payload is
+/// forwarded unchanged, so dense_resource blobs are not decoded or copied.
+class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constOp,
+                                PatternRewriter &rewriter) const override {
+    // TOSA constant verification requires a ranked, statically shaped tensor.
+    auto resultType = dyn_cast<RankedTensorType>(constOp.getType());
+    if (!resultType || !resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          constOp, "result is not a statically shaped ranked tensor");
+
+    // e.g. index and signed-integer tensors stay as arith.constant.
+    if (!isSupportedElementType(resultType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          constOp, "element type is not representable by tosa.const");
+
+    auto elementsAttr = dyn_cast<ElementsAttr>(constOp.getValueAttr());
+    if (!elementsAttr)
+      return rewriter.notifyMatchFailure(constOp,
+                                         "value is not an ElementsAttr");
+
+    // arith.constant carries AllTypesMatch<["value", "result"]>, so for
+    // verified IR the payload type already equals `resultType`; reshaping the
+    // payload is never required. Keep a cheap guard anyway so that malformed
+    // IR is rejected instead of being silently reinterpreted with a different
+    // shape.
+    if (elementsAttr.getType() != resultType)
+      return rewriter.notifyMatchFailure(
+          constOp, "value type does not match the result type");
+
+    // tosa.const accepts the same ElementsAttr payload and result type, so the
+    // attribute is forwarded as-is.
+    auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(),
+                                          resultType, elementsAttr);
+    rewriter.replaceOp(constOp, newConst.getResult());
+    return success();
+  }
+};
+
+struct TosaArithConstantToTosaConstPass
+    : public tosa::impl::TosaArithConstantToTosaConstPassBase<
+          TosaArithConstantToTosaConstPass> {
+  using Base::Base;
+
+  // The pattern creates tosa.const ops, so TOSA must be loaded before it runs.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tosa::TosaDialect, arith::ArithDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet rewrites(context);
+    rewrites.add<ArithConstantToTosaConst>(context);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(rewrites))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
new file mode 100644
index 0000000000000..3f54a68ed3c00
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @rewrite_f32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: return %[[CST]]
+func.func @rewrite_f32_tensor() -> tensor<2xf32> {
+  %cst = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
+  return %cst : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_i32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: return %[[CST]]
+func.func @rewrite_i32_tensor() -> tensor<3xi32> {
+  %cst = arith.constant dense<[1, 0, -1]> : tensor<3xi32>
+  return %cst : tensor<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_i1_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1>
+func.func @rewrite_i1_tensor() -> tensor<2xi1> {
+  %cst = arith.constant dense<[true, false]> : tensor<2xi1>
+  return %cst : tensor<2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_rank0_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32>
+func.func @rewrite_rank0_tensor() -> tensor<f32> {
+  %cst = arith.constant dense<1.234500e+00> : tensor<f32>
+  return %cst : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @preserve_scalar_i32
+// CHECK: %[[CST:.*]] = arith.constant 42 : i32
+func.func @preserve_scalar_i32() -> i32 {
+  %c42_i32 = arith.constant 42 : i32
+  return %c42_i32 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @preserve_index_tensor
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex>
+func.func @preserve_index_tensor() -> tensor<2xindex> {
+  %cst = arith.constant dense<[0, 1]> : tensor<2xindex>
+  return %cst : tensor<2xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_resource_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32>
+func.func @rewrite_resource_tensor() -> tensor<4xf32> {
+  %cst = arith.constant dense_resource<"blob1"> : tensor<4xf32>
+  return %cst : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_quant_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8>
+func.func @rewrite_quant_tensor() -> tensor<2xui8> {
+  %cst = arith.constant dense<[10, 20]> : tensor<2xui8>
+  return %cst : tensor<2xui8>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>
+func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform<i8:f32, 0.5:0>> {
+  %cst = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+  return %cst : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_reshape_collapse_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 2, 3, 4]> : tensor<4xi32>}> : () -> tensor<4xi32>
+func.func @rewrite_reshape_collapse_tensor() -> tensor<4xi32> {
+  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %collapsed = tensor.collapse_shape %cst [[0, 1]] : tensor<2x2xi32> into tensor<4xi32>
+  return %collapsed : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_fp8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN>
+func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> {
+  %cst = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN>
+  return %cst : tensor<2xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_mxint8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8>
+func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> {
+  %cst = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>
+  return %cst : tensor<2x!tosa.mxint8>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/168370


More information about the Mlir-commits mailing list