[Mlir-commits] [mlir] [mlir][tosa] Change MatMul zero-point to inputs (PR #129785)

Tai Ly llvmlistbot at llvm.org
Fri Mar 7 08:52:50 PST 2025


https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/129785

>From a5c6469350889c7490d8cdcc4e213fe24c7417c8 Mon Sep 17 00:00:00 2001
From: Udaya Ranga <udaya.ranga at arm.com>
Date: Mon, 16 Dec 2024 16:23:07 +0000
Subject: [PATCH] [mlir][tosa] Change MatMul zero-point to inputs

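* Change the tosa.matmul a_zp/b_zp zero-point attributes into tensor inputs (operands)
* Update the affected MLIR tests
* Extend the MatMul ShardingInterface indexing maps to cover the new a_zp/b_zp operands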

Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
Change-Id: Ia58b15cba546a948a6a4d8e8ee26a72cd050de4e
---
 .../Dialect/Tosa/IR/TosaComplianceData.h.inc  | 14 ++--
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 11 +++-
 .../TosaToLinalg/TosaToLinalgNamed.cpp        | 41 +++++++++---
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp |  2 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 64 +++++++++++-------
 .../Tosa/Transforms/TosaProfileCompliance.cpp | 11 +++-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 24 +++++--
 .../Dialect/Mesh/sharding-propagation.mlir    | 66 +++++++++++--------
 mlir/test/Dialect/Tosa/availability.mlir      |  4 +-
 mlir/test/Dialect/Tosa/invalid.mlir           | 42 +++++++++++-
 mlir/test/Dialect/Tosa/level_check.mlir       |  3 +-
 mlir/test/Dialect/Tosa/ops.mlir               |  4 +-
 .../Dialect/Tosa/profile_all_unsupported.mlir |  4 +-
 .../Tosa/profile_pro_fp_unsupported.mlir      |  4 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 36 ++++++----
 15 files changed, 232 insertions(+), 98 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..d3fd4c3d1d3e1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -35,9 +35,11 @@ profileComplianceMap = {
         {fp16T, fp16T, fp32T, fp32T},
         {fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.matmul",
-     {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+     {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
       {{Profile::pro_fp},
-       {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp16T, fp32T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.max_pool2d",
      {{{Profile::pro_int}, {{i8T, i8T}}},
       {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -273,10 +275,10 @@ extensionComplianceMap = {
       {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
       {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
     {"tosa.matmul",
-     {{{Extension::int16}, {{i16T, i16T, i48T}}},
-      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
-      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
-      {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+     {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
     {"tosa.max_pool2d",
      {{{Extension::int16}, {{i16T, i16T}}},
       {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5340ce52d73fc..525aa4806c657 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -311,8 +311,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<I32Attr>:$a_zp,
-    OptionalAttr<I32Attr>:$b_zp
+    Tosa_ScalarIntOrFloatTensor:$a_zp,
+    Tosa_ScalarIntOrFloatTensor:$b_zp
   );
 
   let results = (outs
@@ -324,6 +324,13 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
     Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
 
+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getAZeroPoint();
+    FailureOr<int64_t> getBZeroPoint();
+    LogicalResult verifyAZeroPoint(int64_t zp);
+    LogicalResult verifyBZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..13c62b2d3e91c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -270,8 +270,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       return rewriter.notifyMatchFailure(
           op, "weight zero point cannot be statically determined");
 
-    int64_t inputZpVal = *maybeIZp;
-    int64_t weightZpVal = *maybeWZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t weightZpVal = *maybeWZp;
 
     if (op.verifyInputZeroPoint(inputZpVal).failed())
       return rewriter.notifyMatchFailure(
@@ -466,8 +466,8 @@ class DepthwiseConvConverter
       return rewriter.notifyMatchFailure(
           op, "weight zero point cannot be statically determined");
 
-    int64_t inputZpVal = *maybeIZp;
-    int64_t weightZpVal = *maybeWZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t weightZpVal = *maybeWZp;
 
     if (op.verifyInputZeroPoint(inputZpVal).failed())
       return rewriter.notifyMatchFailure(
@@ -621,15 +621,38 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
                            .create<linalg::FillOp>(loc, ValueRange{zero},
                                                    ValueRange{emptyTensor})
                            .result();
-    if (!op.getAZp() && !op.getBZp()) {
+
+    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
+    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
+    if (failed(maybeAZp))
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point cannot be statically determined");
+    if (failed(maybeBZp))
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point cannot be statically determined");
+
+    const int64_t aZpVal = *maybeAZp;
+    const int64_t bZpVal = *maybeBZp;
+
+    if (op.verifyAZeroPoint(aZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point must be zero for non-int8 integer types");
+
+    if (op.verifyBZeroPoint(bZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point must be zero for non-int8 integer types");
+
+    if (aZpVal == 0 && bZpVal == 0) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }
 
-    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
-    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
+    auto aZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(aZpVal));
+    auto bZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(bZpVal));
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -834,8 +857,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
       return rewriter.notifyMatchFailure(
           op, "output zero point could not be statically determined");
 
-    int64_t inputZpVal = *maybeIZp;
-    int64_t outputZpVal = *maybeOZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t outputZpVal = *maybeOZp;
 
     // Apply padding as necessary.
     llvm::SmallVector<int64_t> pad;
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index ffbb707344b8c..6dcb7c845b21f 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -55,6 +55,8 @@ struct MatMulOpSharding
     SmallVector<AffineMap> maps;
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
     return maps;
   }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 854196250bb0c..f8299e45b4f63 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -636,23 +636,13 @@ buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                        OperationState &result, Type outputType,
                                        Value a, Value b) {
-  result.addOperands({a, b});
-  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
+  auto zps = createZPsAsConst(builder, a, b);
+  result.addOperands({a, b, zps.first, zps.second});
 
-  if (quantAttr) {
-    result.addAttribute("a_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getAZp())));
-    result.addAttribute("b_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getBZp())));
-
-    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
-    assert(inputType && "Input must be a shaped tensor type!");
-
-    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
-        inputType.getElementType());
-    assert(inputQType && "Tensor must have quantized datatype!");
-
-    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
+  Type finalOutputType{outputType};
+  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
+    auto eType = getStorageElementTypeOrSelf(a.getType());
+    auto inputBits = eType.getIntOrFloatBitWidth();
 
     auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
     assert(outputShapedType && "Output must be a shaped type");
@@ -662,11 +652,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
       accElementType = builder.getIntegerType(48);
     else
       accElementType = builder.getI32Type();
-    auto accType = outputShapedType.clone(accElementType);
-    result.addTypes(accType);
-  } else {
-    result.addTypes(outputType);
+
+    finalOutputType = outputShapedType.clone(accElementType);
   }
+  result.addTypes(finalOutputType);
 }
 
 /// Both the tosa.avg_pool2d and unary ops use the same
@@ -1147,16 +1136,39 @@ LogicalResult MatMulOp::verify() {
       return emitOpError("expect quantized operands to have same widths, got ")
              << aQuantWidth << " and " << bQuantWidth;
     }
+  } else {
+    // non-quantized element types
+    if (aElementType != bElementType) {
+      return emitOpError("expect same element type for inputs a and b, got ")
+             << aElementType << " and " << bElementType;
+    }
+  }
 
-    return success();
+  // check a_zp and b_zp
+  auto aEType = getStorageElementTypeOrSelf(aType);
+  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
+  if (aEType != aZpEType) {
+    return emitOpError("expect input a and a_zp have the same "
+                       "element type, got ")
+           << aEType << " and " << aZpEType;
   }
 
-  // non-quantized element types
-  if (aElementType != bElementType) {
-    return emitOpError("expect same element type for inputs a and b, got ")
-           << aElementType << " and " << bElementType;
+  auto bEType = getStorageElementTypeOrSelf(bType);
+  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
+  if (bEType != bZpEType) {
+    return emitOpError("expect input b and b_zp have the same "
+                       "element type, got ")
+           << bEType << " and " << bZpEType;
   }
 
+  FailureOr<int64_t> maybeAZp = getAZeroPoint();
+  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeBZp = getBZeroPoint();
+  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
+    return failure();
+
   return success();
 }
 
@@ -1721,6 +1733,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
 ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(MatMulOp, A)
+ZERO_POINT_HELPER(MatMulOp, B)
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..983062ffd7912 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -178,6 +178,15 @@ void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
   addValue(op.getOutput());
 }
 
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
+  addValue(op.getA());
+  addValue(op.getB());
+  addValue(op.getAZp());
+  addValue(op.getBZp());
+  addValue(op.getOutput());
+}
+
 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // This helper function only populates the info for the customised operands.
 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
@@ -218,6 +227,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
   POPULATE_PROFILE_INFO_CUSTOM(Rescale)
+  POPULATE_PROFILE_INFO_CUSTOM(MatMul)
 
   // Type Invariant Extension, a capability extension that is independent
   // of the data type, meaning any compatible type can be used. No type
@@ -235,7 +245,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_COMMON(Cast)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
-  POPULATE_PROFILE_INFO_COMMON(MatMul)
   POPULATE_PROFILE_INFO_COMMON(Sub)
   POPULATE_PROFILE_INFO_COMMON(Maximum)
   POPULATE_PROFILE_INFO_COMMON(Minimum)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..341f773c79a5e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -8,7 +8,9 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x6xf32>)  -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }
 
@@ -23,7 +25,9 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
   // CHECK: [[ONE:%.+]] = arith.constant 1
   // CHECK: [[TWO:%.+]] = arith.constant 2
   // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
-  %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+  %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>
   return %0 : tensor<1x5x6xi32>
 }
 
@@ -37,7 +41,9 @@ func.func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>)
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> tensor<?x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<?x5x3xf32>, tensor<?x3x6xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<?x5x6xf32>
   return %0 : tensor<?x5x6xf32>
 }
 
@@ -51,7 +57,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> tensor<1x5x?xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x?xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<1x5x?xf32>
   return %0 : tensor<1x5x?xf32>
 }
 
@@ -63,7 +71,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
   // CHECK: %[[INIT:.+]] = tensor.empty()
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x?xf32>, tensor<1x?x6xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }
 
@@ -77,7 +87,9 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<?x1x1xf32>
   return %0 : tensor<?x1x1xf32>
 }
 
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 83136f613b020..14c67e670e921 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -98,14 +98,16 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
 }
 
 // CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
   // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
   // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
   // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
   // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
   %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
@@ -115,14 +117,16 @@ func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: ten
 }
 
 // CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
   // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
   // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
   %s1 = mesh.sharding @mesh_2d split_axes = [[], [1]] partial = sum [0] : !mesh.sharding
@@ -132,16 +136,18 @@ func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
 }
 
 // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
   %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s0 annotate_for_users  : tensor<2x16x8xf32>
   // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %1 = tosa.matmul %0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
   // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
   // CHECK-NEXT:  return %[[V3]]
@@ -149,8 +155,8 @@ func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
 }
 
 // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
   %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
@@ -159,8 +165,10 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
   // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
   %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding
   %1 = mesh.shard %arg1 to %s1 annotate_for_users  : tensor<2x8x32xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
   // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
   // CHECK-NEXT:  return %[[V3]]
@@ -199,14 +207,16 @@ func.func @resolve_conflicting_annotations(
 
 // https://arxiv.org/abs/2211.05102 Figure 2(a)
 // CHECK-LABEL: func.func @mlp_1d_weight_stationary
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
-func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32>
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
   %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s0  : tensor<2x4x8xf32>
   // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
   // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
+  // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-DAG:  %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users  : tensor<1xf32>
   // CHECK: %[[V0:.*]] = tosa.matmul
-  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x32xf32>
   // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to %[[S2]]  : tensor<2x4x32xf32>
   // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] annotate_for_users  : tensor<2x4x32xf32>
   // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
@@ -215,8 +225,8 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-NEXT:  %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] annotate_for_users  : tensor<2x4x32xf32>
   // CHECK-DAG: %[[S6:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[ARG2]] to %[[S6]] annotate_for_users  : tensor<2x32x8xf32>
-  // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
-  %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]], %[[ZP]], %[[ZP]]
+  %3 = tosa.matmul %2, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x8xf32>
   %s4 = mesh.sharding @mesh_1d split_axes = [[], [], []] partial = sum [0] : !mesh.sharding
   %4 = mesh.shard %3 to %s4  : tensor<2x4x8xf32>
   // CHECK: %[[S8:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], []] partial =  sum [0] : !mesh.sharding
@@ -230,8 +240,8 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
 
 // https://arxiv.org/abs/2211.05102 Figure 2(b)
 // CHECK-LABEL: func.func @mlp_2d_weight_stationary
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
-func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32>
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
   // CHECK-DAG: %[[S0:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
   // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]]  : tensor<2x4x8xf32>
   %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
@@ -240,8 +250,10 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users  : tensor<2x4x8xf32>
   // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
   // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[ARG1]] to %[[S2]] annotate_for_users  : tensor<2x8x32xf32>
-  // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
-  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-DAG:  %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users  : tensor<1xf32>
+  // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]], %[[ZP]], %[[ZP]]
+  %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x32xf32>
   // CHECK-DAG: %[[S4:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to %[[S4]]  : tensor<2x4x32xf32>
   %s2 = mesh.sharding @mesh_3d split_axes = [[], [], [1, 2]] partial = sum [0] : !mesh.sharding
@@ -254,8 +266,8 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to %[[S5]] annotate_for_users  : tensor<2x4x32xf32>
   // CHECK-DAG: %[[S9:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
   // CHECK-NEXT:  %[[V9:.*]] = mesh.shard %[[ARG2]] to %[[S9]] annotate_for_users  : tensor<2x32x8xf32>
-  // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
-  %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]], %[[ZP]], %[[ZP]]
+  %4 = tosa.matmul %3, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x8xf32>
   // CHECK-DAG: %[[S11:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] partial = sum [1, 2] : !mesh.sharding
   // CHECK-NEXT:  %[[V11:.*]] = mesh.shard %[[V10]] to %[[S11]]  : tensor<2x4x8xf32>
   %s5 = mesh.sharding @mesh_3d split_axes = [[], [], [0]] partial = sum[1, 2] : !mesh.sharding
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 1952ad79392c7..b786264d84106 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -69,10 +69,10 @@ func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (te
 
 // -----
 // CHECK-LABEL: matmul
-func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
   // CHECK: profiles: [ [pro_int, pro_fp] ]
   // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 05700ca3765e4..f536444f6379e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -287,7 +287,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
 // -----
 
 func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>  
+  %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
   %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -1612,3 +1612,43 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
   %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
+
+// -----
+// CHECK-LABEL: test_matmul_a_zp_same_element_type
+func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}}
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_b_zp_same_element_type
+func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // expected-error@+1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}}
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_a_zp_non_zero
+func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  %azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}}
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_b_zp_non_zero
+func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}}
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bc13b614e3f9d..6d8237635d0ec 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1110,8 +1110,9 @@ func.func @test_rfft2d_tensor_size_invalid(%arg0: tensor<536870912x8x16xf32>) ->
 // -----
 
 func.func @test_matmul_tensor_size_invalid(%arg0: tensor<23178x20000x19xf32>, %arg1: tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32> {
+  %zero = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.matmul' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<23178x20000x19xf32>, tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %zero, %zero : (tensor<23178x20000x19xf32>, tensor<23178x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<23178x20000x28xf32>
   return %0 : tensor<23178x20000x28xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e1ac7d5f51d0e..920d66b00d544 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -145,7 +145,9 @@ func.func @test_fft2d_with_local_bound(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1
 // -----
 // CHECK-LABEL: test_matmul
 func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 342c57b0dd85c..d0e97e46f1f6a 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -26,9 +26,9 @@ func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %ar
 }
 
 // -----
-func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %arg2: tensor<1xf32>) -> tensor<1x14x28xf32> {
   // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}}
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 3dd0344e3647d..28c7abdeaf7f7 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -19,9 +19,9 @@ func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %ar
 }
 
 // -----
-func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %arg2: tensor<1xf32>) -> tensor<1x14x28xf32> {
   // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}}
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 55c5c3f6bdfb6..deede4b0afadc 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -279,8 +279,10 @@ func.func @test_dynamic_argmax(%arg0 : tensor<2x?xi32>) -> () {
 
 // CHECK-LABEL: @test_static_matmul
 func.func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<2x3x5xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x3x5xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
 
   return
 }
@@ -289,8 +291,10 @@ func.func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi3
 
 // CHECK-LABEL: @test_dynamic_lhs_matmul
 func.func @test_dynamic_lhs_matmul(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<2x4x5xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>) -> tensor<2x?x5xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x?x5xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
 
   return
 }
@@ -299,8 +303,10 @@ func.func @test_dynamic_lhs_matmul(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<2x4
 
 // CHECK-LABEL: @test_dynamic_rhs_matmul
 func.func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<?x?x?xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>) -> tensor<2x3x?xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x3x?xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
 
   return
 }
@@ -309,8 +315,10 @@ func.func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<?x?
 
 // CHECK-LABEL: @test_dynamic_mixed_matmul
 func.func @test_dynamic_mixed_matmul(%arg0 : tensor<?x3x?xi32>, %arg1 : tensor<?x?x5xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>) -> tensor<?x3x5xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x3x5xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
 
   return
 }
@@ -1405,11 +1413,13 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
 
 // CHECK-LABEL: test_non_tosa_consumer_still_propagates
 func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
-  %1 = arith.constant dense<[1, 1]> : tensor<2xindex>
-  %2 = tensor.reshape %0(%1) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1xf32>
+  %0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %1 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1x1xf32>
+  %3 = arith.constant dense<[1, 1]> : tensor<2xindex>
+  %4 = tensor.reshape %2(%3) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
 }
 
 // -----



More information about the Mlir-commits mailing list