[Mlir-commits] [mlir] 852b648 - [mlir] Improvements to the 'quant' dialect (#100667)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 26 11:09:35 PDT 2024


Author: Rafael Ubal
Date: 2024-09-26T14:09:28-04:00
New Revision: 852b6486246141e44cc9f126f542a2ae0d73b3d6

URL: https://github.com/llvm/llvm-project/commit/852b6486246141e44cc9f126f542a2ae0d73b3d6
DIFF: https://github.com/llvm/llvm-project/commit/852b6486246141e44cc9f126f542a2ae0d73b3d6.diff

LOG: [mlir] Improvements to the 'quant' dialect (#100667)

Full revamp of the 'quant' dialect. This is an implementation for the
RFC at
https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942

Added: 
    mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Quant/IR/Quant.h
    mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
    mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
    mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
    mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
    mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
    mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
    mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
    mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
    mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
    mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
    mlir/test/Dialect/Quant/invalid.mlir
    mlir/test/Dialect/Quant/lower-quant-ops.mlir
    mlir/test/Dialect/Quant/ops.mlir
    mlir/test/Dialect/Quant/strip-func-quant-types.mlir

Modified: 
    mlir/include/mlir/Dialect/Quant/CMakeLists.txt
    mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
    mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
    mlir/include/mlir/InitAllDialects.h
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/CAPI/Dialect/Quant.cpp
    mlir/lib/Dialect/Quant/CMakeLists.txt
    mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
    mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
    mlir/lib/Dialect/Quant/IR/QuantOps.cpp
    mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
    mlir/lib/Dialect/Quant/IR/TypeParser.cpp
    mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
    mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Quant/canonicalize.mlir
    mlir/test/Dialect/Quant/parse-uniform-invalid.mlir

Removed: 
    mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
    mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
    mlir/include/mlir/Dialect/Quant/QuantOps.h
    mlir/include/mlir/Dialect/Quant/QuantOps.td
    mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
    mlir/include/mlir/Dialect/Quant/QuantTypes.h
    mlir/include/mlir/Dialect/Quant/UniformSupport.h


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
index c08f399ee182d8..9f57627c321fb0 100644
--- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
@@ -1,6 +1,2 @@
-add_mlir_dialect(QuantOps quant)
-add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
-
-set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
-mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
-add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
+add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..c08f399ee182d8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(QuantOps quant)
+add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
+mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
+add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
similarity index 59%
rename from mlir/include/mlir/Dialect/Quant/QuantOps.h
rename to mlir/include/mlir/Dialect/Quant/IR/Quant.h
index 14fb3035ab0d38..11a969a3ee5191 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -1,4 +1,4 @@
-//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
+//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_
-#define MLIR_DIALECT_QUANT_QUANTOPS_H_
+#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_
+#define MLIR_DIALECT_QUANT_IR_QUANT_H_
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -19,9 +19,19 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
 
-#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc"
+
+namespace mlir {
+namespace quant {
+
+class QuantizedType;
+class UniformQuantizedType;
+class UniformQuantizedPerAxisType;
+
+} // namespace quant
+} // namespace mlir
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.h.inc"
 
-#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_
+#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_

diff  --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
new file mode 100644
index 00000000000000..791cb9de48d058
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -0,0 +1,297 @@
+//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Quantization dialect, types, and traits.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef QUANT_BASE
+#define QUANT_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Quant_Dialect : Dialect {
+  let name = "quant";
+  let description = [{
+    The `quant` dialect offers a framework for defining and manipulating
+    quantized values. Central to this framework is the `!quant.uniform` data
+    type, used to represent quantized values. This dialect also provides a
+    suite of operations to handle and convert quantized values between their
+    original floating-point representations and the optimized, lower bit-width
+    integer representations. The `quant` dialect is instrumented with
+    transformation passes to lower these operations into other core MLIR
+    dialects, while also flattening all occurrences of quantized types into
+    their integer counterparts.
+
+
+    ## The `!quant.uniform` type
+
+    The quantization process establishes a relationship between two types of
+    values: an *expressed value* and a *stored value*. The former refers to the
+    floating-point representation used in an original machine learning model,
+    capturing the precise numerical characteristics needed for accurate
+    calculations. The latter is the simplified integer representation that
+    resides in memory after quantization. The `!quant.uniform` data type
+    encodes the necessary information for (lossy) round-trip conversion between
+    an expressed and a stored value.
+
+    The `!quant.uniform` type has two variants: per-layer quantization and
+    per-channel (or per-axis) quantization. In per-layer quantization, the
+    quantization information affects an entire tensor uniformly. Conversely, in
+    per-channel quantization, the data type encodes the specific tensor axis
+    that serves as the channel and includes quantization information for each
+    individual channel within the tensor. Below are the specific syntactic and
+    semantic considerations for each modality.
+
+
+    ### Per-layer quantization
+
+    This is the general syntax of the `!quant.uniform` type representing
+    per-layer quantization:
+
+    ```
+    `!quant.uniform` `<`
+      storedType (`<` storageMin `:` storageMax `>`)? `:`
+      expressedType `,`
+      scale (`:` zeroPoint)?
+    `>`
+    ```
+
+    The type contains the following parameters:
+
+    - `storedType`: Integer type of the value stored in memory. This type
+      conveys the bit width and signedness of the quantized stored value.
+      Signed integer types are represented as `'i' bitWidth` (e.g., `i8`),
+      while unsigned integer types are represented as `'u' bitWidth` (e.g.,
+      `u8`).
+
+    - `storageMin`, `storageMax`: Optional bounds for the stored value. If
+      given, they must be within the range of `storedType`. If omitted, the
+      entire range of `storedType` is allowed (e.g., `-128...127` for `i8` or
+      `0...255` for `u8`).
+
+    - `expressedType`: Floating-point type of the value expressed by this
+      quantized type (e.g., `f32`, `f80`, `bf16`, or `tf32`).
+
+    - `scale`: Floating-point value of type `expressedType` used in the
+      conversion between stored and expressed values.
+
+    - `zeroPoint`: Optional integer value of type `storedType` used in the
+      conversion between stored and expressed values. If omitted, the default
+      is 0.
+
+    Type conversions, rounding methods, and clamping actions aside, the
+    relationship between the expressed and stored values as encoded in a
+    quantized type is denoted by the following formula:
+
+    $$
+    expressedValue = (storedValue ~-~ zeroPoint) ~\times~ scale
+    $$
+
+    Operations `quant.qcast` (quantize cast) and `quant.dcast` (dequantize
+    cast) can be used to quantize a floating-point value and dequantize a
+    stored value, respectively. See the documentation for these operations for
+    details on how the quantization and dequantization processes are influenced
+    by the `!quant.uniform` type parameters.
+
+    Here are some examples of the use of `!quant.uniform` with per-layer
+    quantization:
+
+    ```
+    // An 8-bit signed integer type is used to represent a 32-bit float. No
+    // clamping information is provided, so the full [-128, 127] range is
+    // available. The scale is set to 3.0, and the zero point takes its default
+    // 0 value.
+    !quant.uniform<i8:f32, 3.0>
+
+    // A 16-bit unsigned integer type is used to represent a 32-bit float. Out
+    // of the 16 bits, only 10 are used, according to the 0..1023 clamping
+    // range. The type sets the scale to 1.23 and the zero point to 512.
+    !quant.uniform<u16<0:1023>:f32, 1.23:512>
+    ```
+
+    ### Per-channel quantization
+
+    The general syntax of the `!quant.uniform` type representing per-channel
+    quantization is as follows:
+
+    ```
+    `!quant.uniform` `<`
+      storedType (`<` storageMin `:` storageMax `>`)? `:`
+      expressedType `:`
+      channelAxis `,`
+      `{`
+        scale0 (`:` zeroPoint0)? `,`
+        scale1 (`:` zeroPoint1)? ...
+      `}`
+    `>`
+    ```
+
+    In this data type, there are multiple pairs of `scale` and `zeroPoint`
+    values. The `channelAxis` field represents the dimension of the containing
+    tensor acting as the channel. The size of the tensor along this dimension
+    is expected to match the number of provided `scale`-`zeroPoint` pairs, and
+    a given pair *i* applies to all elements in the tensor whose index along
+    dimension `channelAxis` is *i*. A quantized data type using per-channel
+    quantization is always expected to be contained within a tensor type.
+
+    Here are some examples:
+
+    ```
+    // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
+    // floats. Dimension 1 of the tensor acts as the channel dimension. Its
+    // size 3 matches the number of provided scale values. Tensor elements at
+    // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
+    // 5.0, respectively.
+    tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
+
+    // A 2D dynamically sized tensor contains 16-bit unsigned integers
+    // representing 32-bit floats. Dimension 0 of the tensor acts as the
+    // channel dimension. Since 2 scale and zero-point values are provided, the
+    // size of dimension 0 is expected to be 2 at runtime. Tensor elements
+    // [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale
+    // 3.0 and zero point 20.
+    tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
+    ```
+
+
+    ## Per-axis quantization integrity
+
+    When type `!quant.uniform` contains per-axis quantization information, the
+    rules below are enforced. These rules guarantee that the quantization
+    information encoded in the data type is applicable to the context in which
+    the quantized type is used. For efficiency, these rules are actively
+    enforced by the verifiers of `quant` dialect ops, but they must be
+    respected in any context in which the `!quant.uniform` data type is used,
+    such as the header of a `func.func` op, or the input of an arithmetic
+    operation.
+ 
+    - A quantized type with per-channel quantization information must be the
+      element type of a tensor container type, and may not occur directly as
+      the data type of a scalar value.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies per-channel quantization for a
+    // scalar type.
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0, 2.0}>
+
+    // Correct. Type `!quant.uniform` with per-channel quantization is wrapped
+    // in a `tensor` type.
+    %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
+    ```
+
+    - If the tensor containing the `!quant.uniform` type is ranked, its rank
+      must be greater than the channel axis specified in the quantized type.
+
+    ```
+    // Incorrect. The tensor rank (2) is not greater than the channel axis in
+    // the quantized type (3).
+    %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>
+
+    // Correct. The tensor rank (2) is now greater than the channel axis (1):
+    %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:1, {1.0, 2.0}>>
+    ```
+
+    - If the axis dimension in the containing tensor is static, its size must
+      be equal to the number of scales present in the quantized type.
+
+    ```
+    // Incorrect. The channel axis is 1, and the size of dimension 1 in the
+    // containing tensor is 3. However, there are 4 scale values present in the
+    // quantized type.
+    %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>
+
+    // Correct. The quantized type now includes 3 scale values, matching the
+    // size of dimension 1 of the result tensor.
+    %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+    ```
+  }];
+  let cppNamespace = "::mlir::quant";
+  let useDefaultTypePrinterParser = 1;
+}
+
+
+//===----------------------------------------------------------------------===//
+// Type predicates
+//===----------------------------------------------------------------------===//
+
+class quant_ScalarOrTensorOf<Type etype> :
+    Type<Or<[etype.predicate, TensorOf<[etype]>.predicate]>,
+         "scalar or tensor of " # etype.summary>;
+
+def quant_QuantizedType :
+    Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "quantized type">;
+
+def quant_ScalarType :
+    Type<Or<[
+      AnySignlessInteger.predicate,
+      AnyFloat.predicate,
+      quant_QuantizedType.predicate
+    ]>,
+    "signless integer, float, or quantized scalar">;
+
+def quant_IntegerOrQuantizedType :
+    Type<Or<[
+      AnySignlessInteger.predicate,
+      quant_QuantizedType.predicate
+    ]>,
+    "signless integer or quantized type">;
+
+def quant_FloatScalarOrTensor :
+    quant_ScalarOrTensorOf<AnyFloat>;
+
+def quant_IntegerScalarOrTensor :
+    quant_ScalarOrTensorOf<AnySignlessInteger>;
+
+def quant_QuantizedScalarOrTensor :
+    quant_ScalarOrTensorOf<quant_QuantizedType>;
+
+def quant_IntegerOrQuantizedScalarOrTensor :
+    quant_ScalarOrTensorOf<quant_IntegerOrQuantizedType>;
+
+
+//===----------------------------------------------------------------------===//
+// Traits
+//===----------------------------------------------------------------------===//
+
+def quant_SameScalarOrTensorShape :
+    PredOpTrait<
+      "input and result are both scalars or both tensors with matching shape",
+      Or<[
+        And<[
+          TypeIsPred<"input", quant_ScalarType>,
+          TypeIsPred<"result", quant_ScalarType>
+        ]>,
+        And<[
+          TypeIsPred<"input", AnyUnrankedTensor>,
+          TypeIsPred<"result", AnyUnrankedTensor>
+        ]>,
+        And<[
+          TypeIsPred<"input", AnyRankedTensor>,
+          TypeIsPred<"result", AnyRankedTensor>,
+          AllShapesMatch<["input", "result"]>.predicate
+        ]>
+      ]>
+    >;
+
+def quant_IntegerAndQuantizedCombination :
+    PredOpTrait<
+      "input must be integer and result must be quantized, or vice versa",
+      Or<[
+        And<[
+          TypeIsPred<"input", quant_QuantizedScalarOrTensor>,
+          TypeIsPred<"result", quant_IntegerScalarOrTensor>
+        ]>,
+        And<[
+          TypeIsPred<"input", quant_IntegerScalarOrTensor>,
+          TypeIsPred<"result", quant_QuantizedScalarOrTensor>
+        ]>
+      ]>
+    >;
+
+#endif // QUANT_BASE

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
similarity index 100%
rename from mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td

diff  --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
new file mode 100644
index 00000000000000..6ef925146dce66
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -0,0 +1,243 @@
+//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for Quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef QUANT_OPS
+#define QUANT_OPS
+
+include "mlir/Dialect/Quant/IR/QuantBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class quant_Op<string mnemonic, list<Trait> traits> :
+    Op<Quant_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Quantization casts
+//===----------------------------------------------------------------------===//
+
+def quant_DequantizeCastOp : quant_Op<"dcast", [
+    Pure,
+    quant_SameScalarOrTensorShape]> {
+  let summary = "Dequantize cast operation";
+  let description = [{
+    Convert an input quantized value into its expressed floating-point value.
+    The dequantization process consists of the following steps:
+
+    ```
+    def dequantize(quantizedValue: quantizedType) -> expressedType:
+        storedValue = reinterpretCast(quantizedValue, storageType)
+        storedValueFloat = convertIntToFloat(storedValue, expressedType)
+        zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+        expressedValue = (storedValueFloat - zeroPointFloat) * scale
+        return expressedValue
+    ```
+
+    Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained
+    from the corresponding parameters encoded in `quantizedType`. For
+    per-channel quantization, the appropriate `scale` and `zeroPoint` values
+    are used for each tensor element computation according to the channel the
+    element belongs to.
+    
+    The numerical results produced by the algorithm above may vary depending on
+    the rounding methods used by `convertIntToFloat()`, subtraction (`-`), and
+    multiplication (`*`). This operation does not define specific rounding
+    methods; instead, it is the responsibility of a transform pipeline to
+    determine which rounding method to apply when this operation is broken down
+    into lower-level dialects.
+
+    The operation must satisfy the following syntactic constraints:
+
+    - Operand `input` must be a scalar or tensor of type `!quant.uniform`.
+
+    - The result type must be a floating-point scalar or tensor.
+
+    - The `expressedType` parameter of the `!quant.uniform` type of the input
+      must match the floating-point type of the result.
+
+    - The operand and result types must be both scalars or both tensors. If
+      tensors, they must be both ranked or both unranked. If ranked, both must
+      have the same shape, including matching static and dynamic dimensions.
+
+    - If the operand uses per-channel quantization, its `!quant.uniform` type
+      must adhere to the [Per-axis quantization
+      integrity](#per-axis-quantization-integrity) guidelines.
+
+    Examples:
+
+    ```
+    // Dequantize a scalar quantized value
+    %result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
+
+    // Dequantize a dynamically shaped tensor of quantized values
+    %result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
+
+    // Dequantize an unranked tensor using per-axis quantization information
+    %result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>
+    ```
+  }];
+  let arguments = (ins quant_QuantizedScalarOrTensor:$input);
+  let results = (outs quant_FloatScalarOrTensor:$result);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
+  let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// Return the float type of the scalar or tensor result.
+    FloatType getFloatType();
+    
+    /// Return the quantized type of the scalar or tensor input.
+    quant::QuantizedType getQuantizedType();
+  }];
+}
+
+def quant_QuantizeCastOp : quant_Op<"qcast", [
+    Pure,
+    quant_SameScalarOrTensorShape]> {
+  let summary = "Quantize cast operation";
+  let description = [{
+    Convert a floating-point value to a quantized type. The quantization
+    process consists of the following steps:
+
+    ```
+    def quantize(expressedValue: expressedType) -> quantizedType:
+        zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+        scaledValue = expressedValue / scale
+        storedValueFloat = scaledValue + zeroPointFloat
+        storedValue = convertFloatToInt(storedValueFloat, storageType)
+        storedValueClamped = clamp(storedValue, storageMin, storageMax)
+        quantizedValue = reinterpretCast(storedValueClamped, quantizedType)
+        return quantizedValue
+    ```
+
+    Here, `storageType`, `storageMin`, `storageMax`, `expressedType`, `scale`,
+    and `zeroPoint` are obtained from the corresponding parameters encoded in
+    `quantizedType`. For per-channel quantization, the appropriate `scale` and
+    `zeroPoint` values are used for each tensor element computation according
+    to the channel the element belongs to.
+
+    The numerical results produced by the algorithm above may vary depending on
+    the rounding methods used by `convertIntToFloat()`, `convertFloatToInt()`,
+    `clamp()`, division (`/`), and addition (`+`). This operation does not
+    define specific rounding methods; instead, it is the responsibility of a
+    transform pipeline to determine which rounding method to apply when this
+    operation is broken down into lower-level dialects.
+
+    The operation must satisfy the following syntactic constraints:
+
+    - Operand `input` must be a floating-point scalar or tensor.
+
+    - The result type must be a scalar or tensor of type `!quant.uniform`.
+
+    - The `expressedType` parameter in the `!quant.uniform` type of the result
+      must match the floating-point type of the input.
+
+    - The operand and result types must be both scalars or both tensors. If
+      tensors, they must be both ranked or both unranked. If ranked, both must
+      have the same shape, including matching static and dynamic dimensions.
+
+    - If the result uses per-channel quantization, its `!quant.uniform` type
+      must adhere to the [Per-axis quantization
+      integrity](#per-axis-quantization-integrity) guidelines.
+
+    Examples:
+
+    ```
+    // Quantize a scalar floating-point value
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32, 2.0>
+
+    // Quantize a dynamically shaped tensor of floating-point values
+    %result = quant.qcast %input : tensor<?xf32> to tensor<?x!quant.uniform<i8:f32, 2.0>>
+
+    // Quantize an unranked tensor using per-axis quantization information
+    %result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
+    ```
+  }];
+  let arguments = (ins quant_FloatScalarOrTensor:$input);
+  let results = (outs quant_QuantizedScalarOrTensor:$result);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
+  let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// Return the float type of the scalar or tensor input.
+    FloatType getFloatType();
+    
+    /// Return the quantized type of the scalar or tensor result.
+    quant::QuantizedType getQuantizedType();
+  }];
+}
+
+def quant_StorageCastOp : quant_Op<"scast", [
+    Pure,
+    quant_SameScalarOrTensorShape,
+    quant_IntegerAndQuantizedCombination]> {
+  let summary = "Storage cast operation";
+  let description = [{
+    Convert a value from a quantized type to the corresponding signless integer
+    storage type, or vice versa. This conversion simply involves a
+    reinterpretation of the input bits and does not involve any data
+    manipulation.
+
+    The following syntactic restrictions must be met:
+
+    - Operand `input` must be a scalar or tensor of a signless integer or
+      `!quant.uniform` type.
+
+    - The result must be a scalar or tensor of a signless integer or
+      `!quant.uniform` type.
+
+    - If the operand is a scalar or tensor of type integer, the result must be
+      a scalar or tensor of type `!quant.uniform`, and vice versa.
+
+    - The operand and result must be both scalars or both tensors. If tensors,
+      they must be both ranked or both unranked. If ranked, both must have the
+      same shape, including matching static and dynamic dimensions.
+
+    - The width of the `storageType` parameter of the quantized type of the
+      operand or result must match the width of the signless integer type of
+      the operand or result.
+
+    - If the operand or result uses per-channel quantization, its
+      `!quant.uniform` type must adhere to the [Per-axis quantization
+      integrity](#per-axis-quantization-integrity) guidelines.
+
+    Examples:
+
+    ```
+    // Cast a scalar quantized value into its storage type
+    %result = quant.scast %input : !quant.uniform<i8:f32, 2.0> to i8
+
+    // Cast a dynamically shaped tensor of quantized values into their storage type
+    %result = quant.scast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xi8>
+
+    // Cast an unranked tensor of signless integers into a quantized type using
+    // per-channel quantization
+    %result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
+    ```
+  }];
+  let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
+  let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
+  let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// Return the integer type used either in the input or the result.
+    IntegerType getIntegerType();
+    
+    /// Return the quantized type used either in the input or the result.
+    quant::QuantizedType getQuantizedType();
+  }];
+}
+
+#endif // QUANT_OPS

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
similarity index 98%
rename from mlir/include/mlir/Dialect/Quant/QuantTypes.h
rename to mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 57a2aa29833657..43440ba623b9c1 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
-#define MLIR_DIALECT_QUANT_QUANTTYPES_H
+#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
+#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -114,6 +114,10 @@ class QuantizedType : public Type {
   /// The maximum value that storageType can take.
   int64_t getStorageTypeMax() const;
 
+  /// Return whether the storage type has explicit min or max boundaries
+  /// different from the minimum and maximum representable values.
+  bool hasStorageTypeBounds() const;
+
   /// Gets the integral bit width that the underlying storage type can exactly
   /// represent. For integral storage types, this will just be their width.
   unsigned getStorageTypeIntegralWidth() const;
@@ -413,4 +417,4 @@ class CalibratedQuantizedType
 } // namespace quant
 } // namespace mlir
 
-#endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
+#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td
deleted file mode 100644
index 7937265ce2f209..00000000000000
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.td
+++ /dev/null
@@ -1,103 +0,0 @@
-//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This is the operation definition file for Quantization.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef DIALECT_QUANT_QUANT_OPS_
-#define DIALECT_QUANT_QUANT_OPS_
-
-include "mlir/Dialect/Quant/QuantOpsBase.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-
-//===----------------------------------------------------------------------===//
-// Base classes
-//===----------------------------------------------------------------------===//
-
-class quant_Op<string mnemonic, list<Trait> traits> :
-    Op<Quantization_Dialect, mnemonic, traits>;
-
-//===----------------------------------------------------------------------===//
-// Quantization casts
-//===----------------------------------------------------------------------===//
-
-def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
-  let summary = "convert a quantizable type to a quantized type";
-  let description = [{
-    A QuantizeCast `qcast` represents a potential type shift from a quantizable
-    type to a quantized type.
-
-    At runtime, a `qcast` will apply the transformation expressed by its
-    operand and result type. For flexibility during transformation, it is also
-    possible to have a `qcast` that performs no transformation (both its
-    operand and result type are quantizable).
-
-    A `qcast` will typically originate from either:
-      a) An expressed or implied constraint in the source dialect which signals
-         that a certain level of quantization is possible or required.
-      b) An inference made by a quantization algorithm indicating that a
-         quantized representation may be acceptable.
-
-    Especially early in transformation, it is common to have pairs of
-    `qcast` and `dcast` at points where a transition to a quantized type is
-    required. In addition, it is also common to have an identity `qcast`
-    (where the operand and result type are not quantized) at all points where
-    it is legal to use a quantized representation (but is not known to be
-    acceptable).
-  }];
-  let arguments = (ins quant_RealValueType:$arg);
-  let results = (outs quant_RealValueType:$res);
-}
-
-def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
-  let summary = "convert back from a quantized to quantizable (expressed) type operation";
-  let description = [{
-    A DequantizeCast op `dcast` represents the inverse of a `qcast`,
-    converting back from a quantized to quantizable (expressed) type.
-
-    Like `qcast`s, a `dcast` is allowed to have both its operand and result
-    as non quantized types. This facilitates transformations and marks edges
-    where the computation must be carried out in the expressed type.
-
-    Especially early in transformation, it is common to have `dcast`s on
-    all operands to ops that must operate with the expressed type (typically
-    math ops prior to lowering to target-specific, quantized kernels).
-  }];
-  let arguments = (ins quant_RealValueType:$arg);
-  let results = (outs quant_RealValueType:$res);
-}
-
-def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
-  let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
-  let description = [{
-    A StorageCast `scast` represents a cast from or to a type based on the
-    storage type and a type based on a corresponding quantized type.
-
-    This op exists to ensure type coherency for between parts of the computation
-    which are operating directly on an underlying storage type and those which
-    operate on quantized values.
-
-    Examples from storage to quantized type:
-    ```
-    i8 -> !quant<"uniform[i8:f32]{1.0}">
-    ```
-    ```
-    tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
-    ```
-    ```
-    vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
-    ```
-  }];
-  let arguments = (ins quant_RealOrStorageValueType:$arg);
-  let results = (outs quant_RealOrStorageValueType:$res);
-  let hasFolder = 1;
-}
-
-#endif // DIALECT_QUANT_QUANT_OPS_

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
deleted file mode 100644
index da822d0a61deb2..00000000000000
--- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
+++ /dev/null
@@ -1,74 +0,0 @@
-//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Predicates for types in the Quantization dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef DIALECT_QUANT_QUANT_OPS_BASE_
-#define DIALECT_QUANT_QUANT_OPS_BASE_
-
-include "mlir/IR/OpBase.td"
-
-def Quantization_Dialect : Dialect {
-  let name = "quant";
-  let cppNamespace = "::mlir::quant";
-
-  let useDefaultTypePrinterParser = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// Quantization type definitions
-//===----------------------------------------------------------------------===//
-
-class quant_TypedPrimitiveOrContainer<Type etype> :
-    Type<Or<[etype.predicate,
-                TensorOf<[etype]>.predicate,
-                VectorOf<[etype]>.predicate]>,
-         "primitive/tensor/vector of " # etype.summary>;
-
-// An implementation of QuantizedType.
-def quant_QuantizedType :
-    Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "QuantizedType">;
-
-// A primitive type that can represent a real value. This is either a
-// floating point value or a quantized type.
-def quant_RealPrimitiveType :
-    Type<Or<[AnyFloat.predicate, quant_QuantizedType.predicate]>,
-    "real valued primitive (float or quantized type)">;
-
-// A primitive type that can represent a storage value. This is either an
-// integer or quantized type.
-def quant_StoragePrimitiveType :
-    Type<Or<[AnySignlessInteger.predicate, quant_QuantizedType.predicate]>,
-    "quantized storage primitive (integer or quantized type)">;
-
-// A primitive or container of RealPrimitiveType.
-def quant_RealValueType :
-    quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
-
-// A primitive or container of StoragePrimitiveType.
-def quant_StorageValueType :
-    quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
-
-// Either a real valued or storage primitive or container type.
-def quant_RealOrStorageValueType :
-    Type<Or<[quant_RealValueType.predicate, quant_StorageValueType.predicate]>,
-    "real valued or storage primitive or container type">;
-
-// An implementation of UniformQuantizedType.
-def quant_UniformQuantizedType :
-    DialectType<Quantization_Dialect,
-                CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
-                "UniformQuantizedType">;
-
-// Predicate for detecting a container or primitive of UniformQuantizedType.
-def quant_UniformQuantizedValueType :
-    quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
-
-#endif // DIALECT_QUANT_QUANT_OPS_BASE_

diff  --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..30f7c1696bdb9b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant)
+add_public_tablegen_target(MLIRQuantTransformsIncGen)
+
+add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc)

diff  --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
new file mode 100644
index 00000000000000..84be2a21b34ed2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
@@ -0,0 +1,29 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns);
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
new file mode 100644
index 00000000000000..b25296d4db5a99
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -0,0 +1,49 @@
+//===-- Passes.td - Quant pass definition file --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
+  let summary = "Lower quant.dcast and quant.qcast ops";
+  let description = [{
+    Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
+    into other core dialects.
+
+    The lowering process generates storage type casts in the form of
+    `quant.scast` ops to act as an interface between the original quantized
+    types of operands and results and their corresponding storage types used in
+    the generated arithmetic computations.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "linalg::LinalgDialect",
+    "quant::QuantDialect",
+    "shape::ShapeDialect",
+    "tensor::TensorDialect"
+  ];
+}
+
+def StripFuncQuantTypes : Pass<"strip-func-quant-types"> {
+  let summary = "Strip quantized types from function headers";
+  let description = [{
+    Identify occurrences of function arguments using a quantized type and
+    replace them with a new value of the corresponding storage (signless
+    integer) type. For each converted argument, a `quant.scast` op is introduced
+    at the head of the function's entry block converting the new integer
+    argument into the original quantized value.
+  }];
+  let dependentDialects = [
+    "func::FuncDialect",
+    "quant::QuantDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
similarity index 93%
rename from mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
index 367d468b2acf1a..6551efc6242a60 100644
--- a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
@@ -34,10 +34,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
-#define MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 namespace mlir {
 namespace quant {
@@ -64,4 +64,4 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
 } // namespace quant
 } // namespace mlir
 
-#endif // MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_

diff  --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
similarity index 97%
rename from mlir/include/mlir/Dialect/Quant/UniformSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
index 4119aced4c0752..6773f45069c874 100644
--- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
@@ -6,12 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
-#define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
 
 #include <utility>
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
@@ -218,4 +218,4 @@ class UniformQuantizedPerAxisValueConverter {
 } // namespace quant
 } // namespace mlir
 
-#endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 64bacd0e432fe5..67b41187e5bfb7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -40,7 +40,7 @@ def Tosa_Dialect : Dialect {
     there will be tools to lower from the ML frameworks into TOSA.
   }];
 
-  let dependentDialects = ["tensor::TensorDialect", "quant::QuantizationDialect"];
+  let dependentDialects = ["tensor::TensorDialect", "quant::QuantDialect"];
 
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;

diff  --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 298c97015fe2eb..5e80745777b3b3 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -16,8 +16,8 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
 
 namespace mlir {
 namespace tosa {

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 73dccdb017ee14..7fd0432ddce1bb 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -65,7 +65,7 @@
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
@@ -137,7 +137,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   pdl_interp::PDLInterpDialect,
                   polynomial::PolynomialDialect,
                   ptr::PtrDialect,
-                  quant::QuantizationDialect,
+                  quant::QuantDialect,
                   ROCDL::ROCDLDialect,
                   scf::SCFDialect,
                   shape::ShapeDialect,

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 1b9c1b193ace6e..dd8b292a87344e 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -35,6 +35,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
   memref::registerMemRefPasses();
   mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
+  quant::registerQuantPasses();
   registerSCFPasses();
   registerShapePasses();
   spirv::registerSPIRVPasses();

diff  --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 0a7181d8bc17c3..c94dbb5692fdb0 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -8,12 +8,12 @@
 
 #include "mlir-c/Dialect/Quant.h"
 #include "mlir/CAPI/Registration.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 using namespace mlir;
 
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
 
 //===---------------------------------------------------------------------===//
 // QuantizedType

diff  --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt
index 037bba8dcb5c9b..31167e6af908b9 100644
--- a/mlir/lib/Dialect/Quant/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
 add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index c0c00fb4893cb3..6a4ac310eb0524 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -9,8 +9,8 @@
 
 #include "QuantDialectBytecode.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/SmallVector.h"
@@ -31,7 +31,7 @@ static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader,
   return success();
 }
 
-#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc"
 
 /// This class implements the bytecode interface for the Quant dialect.
 struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
@@ -64,6 +64,6 @@ struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
 };
 } // namespace
 
-void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
+void quant::detail::addBytecodeInterface(QuantDialect *dialect) {
   dialect->addInterfaces<QuantDialectBytecodeInterface>();
 }

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
index 9e9cbf66d84d92..eef2b5bbefecc0 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
@@ -15,12 +15,12 @@
 #define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
 
 namespace mlir::quant {
-class QuantizationDialect;
+class QuantDialect;
 
 namespace detail {
 /// Add the interfaces necessary for encoding the quantization dialect
 /// components in bytecode.
-void addBytecodeInterface(QuantizationDialect *dialect);
+void addBytecodeInterface(QuantDialect *dialect);
 } // namespace detail
 } // namespace mlir::quant
 

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index c9a6bbc9ceeea5..c584903f3a15de 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -6,44 +6,209 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantOps.h"
 #include "QuantDialectBytecode.h"
 #include "TypeDetail.h"
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/MathExtras.h"
-#include <numeric>
+#include "mlir/IR/TypeUtilities.h"
 
-using namespace mlir;
-using namespace mlir::quant;
-using namespace mlir::quant::detail;
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
-#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
 
-void QuantizationDialect::initialize() {
+namespace mlir {
+namespace quant {
+
+namespace {
+
+// Verify the integrity of per-axis quantization information, if present.
+//
+// - quantizedType
+//   Any quantized type. Any quantized type with no per-axis quantization is
+//   ignored.
+//
+// - containerType
+//   Original input or result type of the operation using the provided quantized
+//   type. Used to ensure that the quantized type appears within a tensor and
+//   that the tensor is compatible with per-axis quantization information.
+//
+LogicalResult verifyPerAxisQuantization(Operation *op,
+                                        QuantizedType quantizedType,
+                                        Type containerType) {
+  auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+  if (!quantizedPerAxisType)
+    return success();
+
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return op->emitError("scalar types may not use per-axis quantization");
+
+  if (!tensorType.hasRank())
+    return success();
+
+  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+  if (quantizedDimension >= tensorType.getRank())
+    return op->emitError("quantized dimension must be less than tensor rank");
+
+  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
+  if (quantizedDimensionSize != ShapedType::kDynamic &&
+      quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+    return op->emitError(
+        "quantized dimension size does not match number of scales");
+
+  return success();
+}
+
+// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
+//
+// - quantizedType
+//   Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
+//   whether as a primitive type or in a tensor.
+//
+// - floatType
+//   Float type used in the input ('quant.qcast') or result ('quant.dcast'),
+//   whether as a primitive type or in a tensor.
+//
+// - containerType
+//   Type of original input or result.
+//
+LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
+                                   FloatType floatType, Type containerType) {
+  if (quantizedType.getExpressedType() != floatType)
+    return op->emitError(
+        "expressed type in quantized type expected to match float type");
+
+  // Verify integrity of per-axis quantization information, if present.
+  return verifyPerAxisQuantization(op, quantizedType, containerType);
+}
+
+}  // namespace
+
+
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
+
+void QuantDialect::initialize() {
   addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>();
   addOperations<
 #define GET_OP_LIST
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
       >();
-  addBytecodeInterface(this);
+  detail::addBytecodeInterface(this);
+}
+
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DequantizeCastOp::verify() {
+  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+                              getInput().getType());
+}
+
+OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
+  // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
+  // with the value of x. Values x and y are guaranteed to be of the same type
+  // in this pattern.
+  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
+  if (!srcQcastOp)
+    return {};
+  assert(srcQcastOp.getInput().getType() == getType());
+  return srcQcastOp.getInput();
+}
+
+FloatType DequantizeCastOp::getFloatType() {
+  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+QuantizedType DequantizeCastOp::getQuantizedType() {
+  return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult QuantizeCastOp::verify() {
+  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+                              getInput().getType());
+}
+
+OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
+  // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
+  // with the value of x if the casts invert each other. Contrary to the folding
+  // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
+  // x and y are not guaranteed to be of the same type here, as they may use
+  // different quantization parameters.
+  auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
+  if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
+    return {};
+  return srcDcastOp.getInput();
+}
+
+FloatType QuantizeCastOp::getFloatType() {
+  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+QuantizedType QuantizeCastOp::getQuantizedType() {
+  return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+
+//===----------------------------------------------------------------------===//
+// StorageCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult StorageCastOp::verify() {
+  auto quantizedType = getQuantizedType();
+  auto integerType = getIntegerType();
+  if (quantizedType.getStorageType() != integerType)
+    return emitError(
+        "storage type in quantized type expected to match integer type");
+
+  // Verify integrity of per-axis quantization information, if available. While
+  // the quantization type may appear in the input or the result, their tensor
+  // shapes are guaranteed to be identical at this point.
+  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
 }
 
 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
-  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
-  // value of x if the casts invert each other.
-  auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
-  if (!srcScastOp || srcScastOp.getArg().getType() != getType())
-    return OpFoldResult();
-  return srcScastOp.getArg();
+  // Matches x -> quant.scast -> quant.scast -> y, replacing the second
+  // quant.scast with the value of x if the casts invert each other.
+  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
+    return {};
+  return srcScastOp.getInput();
+}
+
+IntegerType StorageCastOp::getIntegerType() {
+  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+  if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
+    return integerType;
+
+  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+  return cast<IntegerType>(resultScalarType);
+}
+
+QuantizedType StorageCastOp::getQuantizedType() {
+  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+  if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
+    return quantizedType;
+
+  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+  return cast<QuantizedType>(resultScalarType);
 }
 
+
+} // namespace quant
+} // namespace mlir
+
 #define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
+

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index c2ba9c04e8771d..ac01b37a553077 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
 #include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -20,12 +20,28 @@ using namespace mlir;
 using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
+namespace {
+
+// Return the smallest positive (possibly denormal) value of a given float type
+double getMinScale(Type expressedType) {
+  auto floatType = cast<FloatType>(expressedType);
+  return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
+}
+
+// Return the maximum scale representable in a given float type
+double getMaxScale(Type expressedType) {
+  auto floatType = cast<FloatType>(expressedType);
+  return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
+}
+
+}  // namespace
+
 unsigned QuantizedType::getFlags() const {
   return static_cast<ImplType *>(impl)->flags;
 }
 
 bool QuantizedType::classof(Type type) {
-  return llvm::isa<QuantizationDialect>(type.getDialect());
+  return llvm::isa<QuantDialect>(type.getDialect());
 }
 
 LogicalResult
@@ -73,6 +89,17 @@ int64_t QuantizedType::getStorageTypeMax() const {
   return static_cast<ImplType *>(impl)->storageTypeMax;
 }
 
+bool QuantizedType::hasStorageTypeBounds() const {
+  unsigned int integralWidth = getStorageTypeIntegralWidth();
+  bool isSignedInteger = isSigned();
+  int64_t defaultIntegerMin =
+      getDefaultMinimumForInteger(isSignedInteger, integralWidth);
+  int64_t defaultIntegerMax =
+      getDefaultMaximumForInteger(isSignedInteger, integralWidth);
+  return defaultIntegerMin != getStorageTypeMin() ||
+         defaultIntegerMax != getStorageTypeMax();
+}
+
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
   // NOTE: If ever supporting non-integral storage types, some other scheme
   // for determining the width will be needed.
@@ -293,8 +320,13 @@ LogicalResult UniformQuantizedType::verifyInvariants(
     return emitError() << "expressed type must be floating point";
 
   // Verify scale.
+  double minScale = getMinScale(expressedType);
+  double maxScale = getMaxScale(expressedType);
   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
     return emitError() << "illegal scale: " << scale;
+  if (scale < minScale || scale > maxScale)
+    return emitError() << "scale out of expressed type range [" << minScale
+                       << ", " << maxScale << "]";
 
   return success();
 }
@@ -353,11 +385,20 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
                        << scales.size() << ", " << zeroPoints.size();
 
   // Verify scale.
+  double minScale = getMinScale(expressedType);
+  double maxScale = getMaxScale(expressedType);
   for (double scale : scales) {
     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
       return emitError() << "illegal scale: " << scale;
+    if (scale < minScale || scale > maxScale)
+      return emitError() << "scale out of expressed type range [" << minScale
+                         << ", " << maxScale << "]";
   }
 
+  // Verify quantized dimension.
+  if (quantizedDimension < 0)
+    return emitError() << "illegal quantized dimension: " << quantizedDimension;
+
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 926a8a0aa13d5c..851763d8942e83 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
@@ -317,7 +317,7 @@ static Type parseCalibratedType(DialectAsmParser &parser) {
 }
 
 /// Parse a type registered to this dialect.
-Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
+Type QuantDialect::parseType(DialectAsmParser &parser) const {
   // All types start with an identifier that we switch on.
   StringRef typeNameSpelling;
   if (failed(parser.parseKeyword(&typeNameSpelling)))
@@ -346,12 +346,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   }
 
   // storageTypeMin and storageTypeMax if not default.
-  int64_t defaultIntegerMin =
-      QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
-  int64_t defaultIntegerMax =
-      QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
-  if (defaultIntegerMin != type.getStorageTypeMin() ||
-      defaultIntegerMax != type.getStorageTypeMax()) {
+  if (type.hasStorageTypeBounds()) {
     out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
         << ">";
   }
@@ -419,7 +414,7 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type,
 }
 
 /// Print a type registered to this dialect.
-void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
+void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
   if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
     printAnyQuantizedType(anyType, os);
   else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))

diff  --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..2fd4a41999d456
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,26 @@
+add_mlir_dialect_library(MLIRQuantTransforms
+  LowerQuantOps.cpp
+  StripFuncQuantTypes.cpp
+
+  # Variable references use ${VAR}; "{$VAR}" is taken literally by CMake.
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
+
+  DEPENDS
+  MLIRQuantTransformsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRFuncTransforms
+  MLIRIndexDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRLinalgUtils
+  MLIRPass
+  MLIRQuantDialect
+  MLIRShapeDialect
+  MLIRTensorDialect
+  MLIRTransforms
+  MLIRTransformUtils
+
+  )

diff  --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
new file mode 100644
index 00000000000000..4adeb9218ff8ec
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -0,0 +1,676 @@
+//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_LOWERQUANTOPS
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+// Return the element type of 'inputType' when it is a tensor; any other type
+// is returned unchanged.
+Type getScalarType(Type inputType) {
+  auto tensorType = dyn_cast<TensorType>(inputType);
+  return tensorType ? tensorType.getElementType() : inputType;
+}
+
+// Describe the shape of 'input' as a mix of attributes (static sizes) and
+// values (dynamic sizes), as produced by tensor::getMixedSizes(). A
+// non-tensor input has no dimensions, so an empty list is returned for it.
+SmallVector<OpFoldResult>
+getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+  if (!isa<TensorType>(input.getType()))
+    return {};
+  return tensor::getMixedSizes(builder, loc, input);
+}
+
+// Build a type with element type 'elementType' that mirrors 'referenceType':
+// a tensor reference yields a tensor of the same shape (via TensorType::clone)
+// and any other reference yields 'elementType' itself.
+Type getScalarOrTensorType(Type elementType, Type referenceType) {
+  auto tensorType = dyn_cast<TensorType>(referenceType);
+  if (!tensorType)
+    return elementType;
+  return tensorType.clone(elementType);
+}
+
+// Broadcast 'scalar' to match 'referenceType'. When 'referenceType' is a
+// tensor, a 'tensor.splat' of 'scalar' with dimensions 'referenceShape' is
+// created and returned. Otherwise 'scalar' is returned unchanged and
+// 'referenceShape' must be empty. Despite the function name, 'scalar' need
+// not be produced by a constant op (e.g. it may be a linalg block argument).
+Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
+                                Type referenceType,
+                                ArrayRef<OpFoldResult> referenceShape) {
+  if (isa<TensorType>(referenceType))
+    return builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
+
+  // Scalar result: there is nothing to broadcast.
+  assert(referenceShape.empty());
+  return scalar;
+}
+
+// Collapse an unranked tensor into a 1D ranked tensor holding all of its
+// elements.
+//
+// - input
+//   Unranked tensor.
+//
+// Return values:
+//
+// - flatInput
+//   1D ranked tensor with one dynamic dimension, holding the elements of
+//   'input'.
+//
+// - inputShape
+//   1D extent tensor holding the original runtime shape of 'input', so the
+//   shape can be restored later.
+//
+std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
+                                              Value input) {
+  MLIRContext *ctx = builder.getContext();
+
+  // Query the runtime shape of the input and its total number of elements.
+  Value originalShape = builder.create<shape::ShapeOfOp>(
+      loc, shape::getExtentTensorType(ctx), input);
+  Value numElements = builder.create<shape::NumElementsOp>(
+      loc, builder.getIndexType(), originalShape);
+
+  // Pack the element count into a single-entry shape operand.
+  Value newShape = builder.create<tensor::FromElementsOp>(
+      loc, shape::getExtentTensorType(ctx, 1), numElements);
+
+  // Reshape into a tensor<?xelementType>.
+  Type elementType = cast<UnrankedTensorType>(input.getType()).getElementType();
+  auto flatType = RankedTensorType::get({ShapedType::kDynamic}, elementType);
+  Value flattened =
+      builder.create<tensor::ReshapeOp>(loc, flatType, input, newShape);
+  return {flattened, originalShape};
+}
+
+// Reshape an unranked tensor into a 3D ranked tensor where the central
+// dimension of the result tensor corresponds to dimension 'axis' of the input
+// tensor.
+//
+// - input
+//   Unranked tensor.
+//
+// - axis
+//   Index of the input dimension around which the other input dimensions will
+//   be collapsed. Dimensions [0, axis) are folded into the first result
+//   dimension and dimensions (axis, rank) into the last one.
+//
+// - axisSize
+//   Size of input dimension 'axis'. It becomes the static middle dimension of
+//   the result. Presumably it must equal the runtime size of that dimension
+//   for the reshape to be valid — callers are expected to guarantee this.
+//
+// Return values:
+//
+// - flatInput
+//   3D ranked tensor of shape [?, axisSize, ?].
+//
+// - inputShape
+//   1D extent tensor containing the shape of the original unranked input.
+//
+std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
+                                                        Location loc,
+                                                        Value input,
+                                                        int64_t axis,
+                                                        int64_t axisSize) {
+  // Get full tensor shape
+  auto *context = builder.getContext();
+  auto indexType = builder.getIndexType();
+  auto shapeType = shape::getExtentTensorType(context);
+  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+
+  // Get shapes and element counts on the left of 'axis' (dimensions
+  // [0, axis)) and on its right (dimensions (axis, rank)).
+  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
+  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
+  auto shapeLeft = builder.create<shape::SplitAtOp>(
+      loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
+      .getResult(0);
+  auto sizeLeft = builder.create<shape::NumElementsOp>(
+      loc, indexType, shapeLeft);
+  auto shapeRight = builder.create<shape::SplitAtOp>(
+      loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
+      .getResult(1);
+  auto sizeRight = builder.create<shape::NumElementsOp>(
+      loc, indexType, shapeRight);
+
+  // Compute flat input shape as a 3-element 1D tensor
+  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
+  auto flatShapeType = shape::getExtentTensorType(context, 3);
+  auto flatInputShape = builder.create<tensor::FromElementsOp>(
+      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
+
+  // Reshape input to 3D tensor
+  auto inputType = cast<UnrankedTensorType>(input.getType());
+  auto elementType = inputType.getElementType();
+  auto flatInputType = RankedTensorType::get(
+      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
+  auto flatInput = builder.create<tensor::ReshapeOp>(
+      loc, flatInputType, input, flatInputShape);
+
+  return std::make_pair(flatInput, inputShape);
+}
+
+// Reshape a ranked tensor back into an unranked tensor whose runtime shape is
+// given by 'inputShape'.
+//
+// - input
+//   Ranked tensor.
+//
+// - inputShape
+//   1D extent tensor, e.g. the shape returned by flattenUnrankedTensor().
+//
+Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
+                                 Value inputShape) {
+  Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
+  return builder.create<tensor::ReshapeOp>(
+      loc, UnrankedTensorType::get(elementType), input, inputShape);
+}
+
+// Materialize the per-channel scales of 'quantizedType' as a 1D
+// 'arith.constant' tensor whose element type is the expressed type. Example:
+//
+//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
+//
+Value materializePerChannelScales(OpBuilder &builder, Location loc,
+                                  UniformQuantizedPerAxisType quantizedType) {
+  auto scales = quantizedType.getScales();
+  Type expressedType = quantizedType.getExpressedType();
+  SmallVector<Attribute> scaleAttrs;
+  scaleAttrs.reserve(scales.size());
+  for (double scale : scales)
+    scaleAttrs.push_back(builder.getFloatAttr(expressedType, scale));
+  auto scalesType = RankedTensorType::get(
+      {static_cast<int64_t>(scales.size())}, expressedType);
+  auto scalesAttr = DenseElementsAttr::get(scalesType, scaleAttrs);
+  return builder.create<arith::ConstantOp>(loc, scalesType, scalesAttr);
+}
+
+// Materialize the per-channel zero points of 'quantizedType' as a 1D
+// 'arith.constant' tensor whose element type is the storage type. Example:
+//
+//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
+//
+Value materializePerChannelZeroPoints(
+    OpBuilder &builder, Location loc,
+    UniformQuantizedPerAxisType quantizedType) {
+  auto zeroPoints = quantizedType.getZeroPoints();
+  Type storageType = quantizedType.getStorageType();
+  SmallVector<Attribute> zeroPointAttrs;
+  zeroPointAttrs.reserve(zeroPoints.size());
+  for (int64_t zeroPoint : zeroPoints)
+    zeroPointAttrs.push_back(builder.getIntegerAttr(storageType, zeroPoint));
+  auto zeroPointsType = RankedTensorType::get(
+      {static_cast<int64_t>(zeroPoints.size())}, storageType);
+  auto zeroPointsAttr = DenseElementsAttr::get(zeroPointsType, zeroPointAttrs);
+  return builder.create<arith::ConstantOp>(loc, zeroPointsType,
+                                           zeroPointsAttr);
+}
+
+// Clamp the given scalar or tensor input using the storage bounds encoded in
+// the given quantized type, if present. When the type uses the full range of
+// its storage type, 'input' is returned unchanged and no ops are emitted.
+//
+// - input
+//   Scalar or ranked tensor input. The element type must match the storage type
+//   of 'quantizedType'.
+//
+// - inputShape
+//   If 'input' is a tensor, combination of attributes/values representing its
+//   static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - quantizedType
+//   Per-layer or per-channel quantized type.
+Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+                          ArrayRef<OpFoldResult> inputShape,
+                          QuantizedType quantizedType) {
+  // If quantized type does not narrow down the storage type range, there is
+  // nothing to do.
+  if (!quantizedType.hasStorageTypeBounds())
+    return input;
+
+  // Materialize bounds
+  auto inputType = input.getType();
+  auto storageType = quantizedType.getStorageType();
+  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
+      loc, quantizedType.getStorageTypeMin(), storageType);
+  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
+      loc, quantizedType.getStorageTypeMax(), storageType);
+  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
+                                              inputType, inputShape);
+  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
+                                              inputType, inputShape);
+
+  // Clamp, using signed or unsigned comparisons to match the storage type.
+  if (quantizedType.isSigned()) {
+    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
+    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
+  } else {
+    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
+    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
+  }
+  return input;
+}
+
+// Convert the floating-point 'input' to the integer type 'resultType',
+// emitting 'arith.fptosi' when 'isSigned' is set and 'arith.fptoui' otherwise.
+Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
+                            Type resultType, bool isSigned) {
+  if (!isSigned)
+    return builder.create<arith::FPToUIOp>(loc, resultType, input);
+  return builder.create<arith::FPToSIOp>(loc, resultType, input);
+}
+
+// Convert the integer 'input' to the floating-point type 'resultType',
+// emitting 'arith.sitofp' when 'isSigned' is set and 'arith.uitofp' otherwise.
+Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
+                            Type resultType, bool isSigned) {
+  if (!isSigned)
+    return builder.create<arith::UIToFPOp>(loc, resultType, input);
+  return builder.create<arith::SIToFPOp>(loc, resultType, input);
+}
+
+// Quantize a scalar or ranked tensor value:
+//
+//   stored = toStorage(input / scale + zeroPoint)
+//
+// The zero-point addition is skipped when 'zeroPoint' matches a constant zero.
+// The stored value is then clamped using the storage bounds encoded in the
+// given quantized type, if the type narrows the storage range.
+//
+// NOTE(review): the float-to-integer conversion ('arith.fptosi' or
+// 'arith.fptoui') is emitted before clamping, and no clamp is emitted at all
+// when the type uses the full storage range. Scaled values outside the storage
+// type's representable range are therefore not saturated here. Confirm that
+// this, and the rounding performed by that conversion, are intended.
+//
+// See function 'convertRanked()' below for a description of the arguments.
+Value quantizeValue(OpBuilder &builder, Location loc, Value input,
+                    ArrayRef<OpFoldResult> inputShape, Value scale,
+                    Value zeroPoint, QuantizedType quantizedType) {
+  // Convert scale to tensor if necessary
+  auto inputType = input.getType();
+  scale = getScalarOrTensorConstant(
+      builder, loc, scale, inputType, inputShape);
+
+  // Scale input
+  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+
+  // Skip unnecessary computations if no zero point is given
+  Value storedValueFloat = scaledValue;
+  if (!matchPattern(zeroPoint, m_Zero())) {
+    // Convert zero point to tensor if necessary
+    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+                                          inputShape);
+
+    // Convert zero point from storage to expressed type
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+                                      scale.getType(),
+                                      quantizedType.isSigned());
+
+    // Add zero point to stored value
+    storedValueFloat =
+        builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+  }
+
+  // Convert stored value to storage type
+  auto storageScalarOrTensorType =
+      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
+  auto storedValueInt = convertFloatToInteger(
+      builder, loc, storedValueFloat, storageScalarOrTensorType,
+      quantizedType.isSigned());
+
+  // Clamp the stored value if the quantized type bounds the storage range
+  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
+                                                inputShape, quantizedType);
+  return storedValueClamped;
+}
+
+// Dequantize a scalar or ranked tensor input:
+//
+//   result = (toFloat(input) - zeroPoint) * scale
+//
+// The zero-point subtraction is skipped when 'zeroPoint' matches a constant
+// zero.
+//
+// See function 'convertRanked()' below for a description of the arguments.
+Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
+                      ArrayRef<OpFoldResult> inputShape, Value scale,
+                      Value zeroPoint, QuantizedType quantizedType) {
+  // Convert scale to tensor if necessary
+  auto inputType = input.getType();
+  scale = getScalarOrTensorConstant(
+      builder, loc, scale, inputType, inputShape);
+
+  // Convert stored value to float
+  auto result = convertIntegerToFloat(
+      builder, loc, input, scale.getType(), quantizedType.isSigned());
+
+  // Skip unnecessary computations if no zero point is given
+  if (!matchPattern(zeroPoint, m_Zero())) {
+    // Convert zero point to tensor if necessary
+    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+                                          inputShape);
+
+    // Convert zero point from storage to expressed type
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+                                      scale.getType(),
+                                      quantizedType.isSigned());
+
+    // Subtract zero point from stored value
+    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
+  }
+
+  // Multiply by scale
+  result = builder.create<arith::MulFOp>(loc, result, scale);
+  return result;
+}
+
+// Quantize ('quant.qcast') or dequantize ('quant.dcast') a scalar or ranked
+// tensor input with the given scale and zero point values.
+//
+// - input
+//   Scalar or ranked tensor value.
+//
+// - inputShape
+//   If 'input' is a tensor, combination of attributes/values representing its
+//   static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - scale
+//   Scale as a floating-point scalar value.
+//
+// - zeroPoint
+//   Zero point as an integer scalar value.
+//
+// - quantizedType
+//   Scalar quantized type of the result ('quant.qcast') or of the input
+//   ('quant.dcast').
+//
+Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
+                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
+                    Value zeroPoint, QuantizedType quantizedType) {
+  if (!isa<QuantizeCastOp, DequantizeCastOp>(op))
+    llvm_unreachable("unexpected quant op");
+
+  // Both helpers share one signature; pick the direction from the op kind.
+  auto convert = isa<QuantizeCastOp>(op) ? quantizeValue : dequantizeValue;
+  return convert(builder, loc, input, inputShape, scale, zeroPoint,
+                 quantizedType);
+}
+
+// Lower a 'quant.dcast' or 'quant.qcast' op that uses per-layer quantization
+// on a scalar or ranked tensor input. The type's single scale and zero point
+// are materialized as scalar 'arith.constant' ops.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar or ranked tensor.
+//
+// - quantizedType
+//   Per-layer quantized type.
+//
+Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
+                            Value input, UniformQuantizedType quantizedType) {
+  Type expressedType = quantizedType.getExpressedType();
+  Type storageType = quantizedType.getStorageType();
+
+  Value scale = builder.create<arith::ConstantOp>(
+      loc, expressedType,
+      builder.getFloatAttr(expressedType, quantizedType.getScale()));
+  Value zeroPoint = builder.create<arith::ConstantOp>(
+      loc, storageType,
+      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()));
+
+  SmallVector<OpFoldResult> inputShape =
+      getScalarOrTensorShape(builder, loc, input);
+  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
+                       quantizedType);
+}
+
+// Lower a 'quant.dcast' or 'quant.qcast' op that uses per-layer quantization.
+// Unranked tensor inputs are flattened to 1D, converted, and reshaped back to
+// their original runtime shape.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar, ranked tensor, or unranked tensor.
+//
+// - quantizedType
+//   Per-layer quantized type.
+//
+Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
+                      Value input, UniformQuantizedType quantizedType) {
+  // Scalars and ranked tensors are handled directly.
+  if (!isa<UnrankedTensorType>(input.getType()))
+    return convertPerLayerRanked(builder, loc, op, input, quantizedType);
+
+  auto [flatInput, originalShape] = flattenUnrankedTensor(builder, loc, input);
+  Value flatResult =
+      convertPerLayerRanked(builder, loc, op, flatInput, quantizedType);
+  return restoreUnrankedTensorShape(builder, loc, flatResult, originalShape);
+}
+
+// Lower a 'quant.dcast' or 'quant.qcast' op that uses per-channel
+// quantization on a ranked tensor input. A 'linalg.generic' op applies the
+// scalar conversion element-wise, reading for every element the scale and
+// zero point of its channel.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Ranked tensor. Scalars are not supported: the input type is cast to
+//   RankedTensorType.
+//
+// - quantizedType
+//   Per-channel quantized type.
+//
+// - channelAxis
+//   Dimension of 'input' that indexes channels. An element at position i
+//   along this dimension uses the i-th scale and zero point of
+//   'quantizedType'. Presumably must be in [0, rank of 'input'); callers are
+//   expected to guarantee this.
+//
+Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
+                              Value input,
+                              UniformQuantizedPerAxisType quantizedType,
+                              int64_t channelAxis) {
+  auto *context = builder.getContext();
+
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto inputRank = inputType.getRank();
+
+  auto scales = materializePerChannelScales(builder, loc, quantizedType);
+  auto zeroPoints =
+      materializePerChannelZeroPoints(builder, loc, quantizedType);
+
+  // A float input means quantization (result in the storage type); an integer
+  // input means dequantization (result in the expressed type).
+  auto elementType = isa<FloatType>(inputType.getElementType())
+                         ? quantizedType.getStorageType()
+                         : quantizedType.getExpressedType();
+  auto initShape = tensor::getMixedSizes(builder, loc, input);
+  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
+
+  // Input and output are indexed with the identity map; scales and zero
+  // points are indexed by the channel dimension only.
+  SmallVector<utils::IteratorType> iteratorTypes(
+      inputRank, utils::IteratorType::parallel);
+  auto channelAxisAffineMap = AffineMap::get(
+      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+  SmallVector<AffineMap> indexingMaps{
+    builder.getMultiDimIdentityMap(inputRank),
+    channelAxisAffineMap,
+    channelAxisAffineMap,
+    builder.getMultiDimIdentityMap(inputRank)
+  };
+  auto result = builder.create<linalg::GenericOp>(
+      loc,
+      init.getType(),  // resultType
+      ValueRange{input, scales, zeroPoints},  // inputs
+      ValueRange{init},  // outputs
+      indexingMaps,
+      iteratorTypes,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Block arguments: input element, scale, zero point, output element.
+        assert(args.size() == 4);
+        Value inputElement = args[0];
+        Value scaleElement = args[1];
+        Value zeroPointElement = args[2];
+
+        Value converted =
+            convertRanked(nestedBuilder, nestedLoc, op, inputElement, {},
+                          scaleElement, zeroPointElement, quantizedType);
+
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, converted);
+      })
+      .getResult(0);
+
+  return result;
+}
+
+// Lower a 'quant.dcast' or 'quant.qcast' op that uses per-channel
+// quantization.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Ranked or unranked tensor. An unranked input is reshaped into a 3D tensor
+//   around the channel axis, which then becomes dimension 1, and the original
+//   shape is restored afterwards. Scalars are not supported by the ranked path.
+//
+// - quantizedType
+//   Per-channel quantized type.
+//
+Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
+                        Value input,
+                        UniformQuantizedPerAxisType quantizedType) {
+  // Flatten unranked tensor into a 3D ranked tensor if necessary
+  bool isUnranked = isa<UnrankedTensorType>(input.getType());
+  int64_t channelAxis = quantizedType.getQuantizedDimension();
+  int64_t channelAxisSize =
+      static_cast<int64_t>(quantizedType.getScales().size());
+  Value inputShape;
+  if (isUnranked) {
+    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+        builder, loc, input, channelAxis, channelAxisSize);
+    channelAxis = 1;
+  }
+
+  // Work on a ranked tensor
+  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
+                                        channelAxis);
+
+  // Restore original tensor shape if unranked
+  if (isUnranked)
+    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+  return result;
+}
+
+// Lower a quantization op by dispatching on its quantized type.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar, ranked tensor, or unranked tensor. The element type matches
+//   the storage type (quant.dcast) or expressed type (quant.qcast) of
+//   'quantizedType'.
+//
+// - quantizedType
+//   Per-layer (UniformQuantizedType) or per-channel
+//   (UniformQuantizedPerAxisType) quantized type. Any other type is a
+//   programming error.
+//
+Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
+                       Value input, Type quantizedType) {
+  if (auto perAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
+    return convertPerChannel(builder, loc, op, input, perAxisType);
+  if (auto perLayerType = dyn_cast<UniformQuantizedType>(quantizedType))
+    return convertPerLayer(builder, loc, op, input, perLayerType);
+  llvm_unreachable("unexpected quantized type");
+}
+
+// Lowering pattern for 'quant.dcast'. The quantized input is reinterpreted in
+// its storage type through a quant::StorageCastOp, then dequantized using the
+// scale and zero point(s) of its quantized type.
+struct DequantizeCastOpConversion
+    : public OpConversionPattern<quant::DequantizeCastOp> {
+  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Conversion patterns should read operands from the adaptor, which holds
+    // the already-remapped values, rather than from the original op.
+    Value input = adaptor.getInput();
+    auto quantizedType = cast<QuantizedType>(getScalarType(input.getType()));
+
+    // Convert quantized input to storage type
+    Type storageScalarOrTensorType =
+        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
+    Value storedValue = rewriter.create<quant::StorageCastOp>(
+        loc, storageScalarOrTensorType, input);
+
+    Value result =
+        convertQuantized(rewriter, loc, op, storedValue, quantizedType);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+// Lowering pattern for 'quant.qcast'. The float input is quantized into values
+// of the storage type, which are then reinterpreted as the quantized result
+// type through a quant::StorageCastOp.
+struct QuantizeCastOpConversion
+    : public OpConversionPattern<quant::QuantizeCastOp> {
+  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Conversion patterns should read operands from the adaptor, which holds
+    // the already-remapped values, rather than from the original op.
+    Value input = adaptor.getInput();
+    Type resultType = op.getResult().getType();
+    Type quantizedType = getScalarType(resultType);
+
+    // Quantize the input into stored values (unranked inputs are handled
+    // inside convertQuantized).
+    Value storedValue =
+        convertQuantized(rewriter, loc, op, input, quantizedType);
+
+    // Cast stored value to result quantized value
+    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(op, resultType,
+                                                      storedValue);
+    return success();
+  }
+};
+
+// Pass that lowers 'quant.dcast' and 'quant.qcast' to arith/linalg/shape/tensor
+// ops through a partial dialect conversion. quant::StorageCastOp stays legal
+// because the lowering patterns themselves emit it.
+struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+
+    RewritePatternSet patterns(&ctx);
+    populateLowerQuantOpsPatterns(patterns);
+
+    ConversionTarget target(ctx);
+    target.addLegalOp<quant::StorageCastOp>();
+    target.addIllegalDialect<quant::QuantDialect>();
+    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+                           shape::ShapeDialect, tensor::TensorDialect>();
+
+    LogicalResult status =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+// Add the 'quant.dcast' and 'quant.qcast' lowering patterns to 'patterns'.
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<DequantizeCastOpConversion>(context);
+  patterns.add<QuantizeCastOpConversion>(context);
+}
+
+} // namespace quant
+} // namespace mlir

diff  --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
new file mode 100644
index 00000000000000..8996eff61a39c0
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -0,0 +1,114 @@
+//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Strips quantized types from function headers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+// Type converter that replaces quantized types with their storage type, both
+// as standalone types and as tensor element types. Casts between original and
+// converted values are materialized with quant::StorageCastOp.
+class QuantizedTypeConverter : public TypeConverter {
+
+  // A quantized type converts to its storage type.
+  static Type convertQuantizedType(QuantizedType quantizedType) {
+    return quantizedType.getStorageType();
+  }
+  
+  // A tensor of quantized elements converts to a tensor of the same shape with
+  // storage-type elements; other tensors are kept as-is.
+  static Type convertTensorType(TensorType tensorType) {
+    if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
+      return tensorType.clone(convertQuantizedType(quantizedType));
+    return tensorType;
+  }
+
+  // Bridge a single value to 'type' with a storage cast. Used for argument,
+  // source, and target materializations alike.
+  static Value materializeConversion(OpBuilder &builder, Type type,
+                                     ValueRange inputs, Location loc) {
+    assert(inputs.size() == 1);
+    return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+  }
+
+public:
+
+  explicit QuantizedTypeConverter() {
+    // Registration order matters: per the MLIR TypeConverter documentation,
+    // the most recently added conversion is tried first. The identity
+    // fallback is therefore registered first, so it applies only to types
+    // the two specific conversions do not handle.
+    addConversion([](Type type) { return type; });
+    addConversion(convertQuantizedType);
+    addConversion(convertTensorType);
+
+    addArgumentMaterialization(materializeConversion);
+    addSourceMaterialization(materializeConversion);
+    addTargetMaterialization(materializeConversion);
+  }
+};
+
+// Conversion pass that rewrites func.func signatures, func.return operands and
+// func.call operands/results to use storage types instead of quantized types.
+// Storage casts are inserted by QuantizedTypeConverter where values must be
+// bridged between the two representations.
+class StripFuncQuantTypes
+    : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+public:
+  void runOnOperation() override {
+    auto moduleOp = cast<ModuleOp>(getOperation());
+    auto *context = &getContext();
+
+    QuantizedTypeConverter typeConverter;
+    ConversionTarget target(*context);
+    RewritePatternSet patterns(context);
+
+    // Mark func.func, func.return, and func.call illegal if they contain any
+    // quantized types. For func.func, both the signature and the body's block
+    // argument types are checked.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>(
+        [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
+    target.addDynamicallyLegalOp<func::CallOp>(
+        [&](func::CallOp op) { return typeConverter.isLegal(op); });
+
+    // Register conversion patterns
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+
+    // Apply conversion
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir
+

diff  --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index 8c69729824691c..fb27640bfd2784 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
 
 using namespace mlir;
 using namespace mlir::quant;

diff  --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index 408701f80444a1..62c7a7128d63a6 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include <numeric>
 

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 03876a7c64d07c..c62942e1be78e2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -11,7 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6dce3d03066c9a..7f740be4efb4ff 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -14,7 +14,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"

diff  --git a/mlir/test/Dialect/Quant/canonicalize.mlir b/mlir/test/Dialect/Quant/canonicalize.mlir
index 36c3eaf5e10d20..73c57e2a48212a 100644
--- a/mlir/test/Dialect/Quant/canonicalize.mlir
+++ b/mlir/test/Dialect/Quant/canonicalize.mlir
@@ -1,24 +1,124 @@
 // RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
 
+// quant.dcast(quant.qcast(x)) is expected to fold straight back to x.
+// NOTE(review): this fold drops whatever rounding/clamping a real qcast->dcast
+// round trip would apply; confirm that is acceptable for canonicalization.
+// CHECK-LABEL: @dcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+  %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 // -----
-// CHECK-LABEL: redundant_scast
-func.func @redundant_scast() -> tensor<4xi8> {
-  // CHECK-NEXT: arith.constant dense<10> : tensor<4xi8>
-  // CHECK-NEXT: return
-  %cst = arith.constant dense<5> : tensor<4xi8>
-  %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
-  %2 = "quant.scast"(%1) : (tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<4xi8>
-  %3 = arith.addi %2, %2 : tensor<4xi8>
-  return %3 : tensor<4xi8>
+
+// A dcast whose operand is produced by quant.scast (not quant.qcast) must be
+// left in place.
+// CHECK-LABEL: @dcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.dcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_no_fold_source(%arg0: tensor<4xi8>) -> tensor<4xf32> {
+  %0 = quant.scast %arg0 : tensor<4xi8> to tensor<4x!qalias>
+  %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+  return %1 : tensor<4xf32>
 }
 
 // -----
-// CHECK-LABEL: non_redundant_scast
-func.func @non_redundant_scast() -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>> {
-  // CHECK-NEXT: arith.constant dense<5> : tensor<4xi8>
-  // CHECK-NEXT: scast
-  // CHECK-NEXT: return
-  %cst = arith.constant dense<5> : tensor<4xi8>
-  %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
-  return %1 : tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
+
+// quant.qcast(quant.dcast(x)) back into the same quantized type folds to x.
+// CHECK-LABEL: @qcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+  %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+  %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+  return %1 : tensor<4x!qalias>
 }
+
+// -----
+
+// A qcast whose operand is not produced by quant.dcast (here arith.negf) must
+// be left in place.
+// CHECK-LABEL: @qcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = arith.negf %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4x!qalias> {
+  %0 = arith.negf %arg0 : tensor<4xf32>
+  %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+  return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// dcast followed by qcast into a different quantized type (scale 2.0 vs 3.0)
+// must not fold.
+// CHECK-LABEL: @qcast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.dcast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @qcast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+  %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+  %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias1>
+  return %1 : tensor<4x!qalias1>
+}
+
+// -----
+
+// scast to the storage type and back into the same quantized type folds to
+// the original value.
+// CHECK-LABEL: @scast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+  %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+  %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias>
+  return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// An scast whose operand comes from quant.qcast (not another quant.scast)
+// must be left in place.
+// CHECK-LABEL: @scast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[QCAST:.*]] = quant.qcast %[[ARG_0]]
+// CHECK: %[[SCAST:.*]] = quant.scast %[[QCAST]]
+// CHECK: return %[[SCAST]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4xi8> {
+  %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+  %1 = quant.scast %0 : tensor<4x!qalias> to tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
+// An scast round trip through i8 into a different quantized type (scale 2.0
+// vs 3.0) must not fold.
+// CHECK-LABEL: @scast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.scast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @scast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+  %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+  %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias1>
+  return %1 : tensor<4x!qalias1>
+}
+

diff  --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
new file mode 100644
index 00000000000000..ba3a8e312d96e9
--- /dev/null
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -0,0 +1,258 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// dcast requires a quantized operand.
+func.func @dcast_invalid_input(%arg0: f32) {
+  // expected-error@+1 {{operand #0 must be scalar or tensor of quantized type}}
+  %0 = quant.dcast %arg0 : f32 to f32
+  return
+}
+
+// -----
+
+// dcast requires a floating-point result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_invalid_result(%arg0: !qalias) {
+  // expected-error@+1 {{result #0 must be scalar or tensor of floating-point}}
+  %0 = quant.dcast %arg0 : !qalias to !qalias
+  return
+}
+
+// -----
+
+// dcast operand and result must both be scalars or both be tensors.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_scalar_tensor(%arg0: !qalias) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.dcast %arg0 : !qalias to tensor<f32>
+  return
+}
+
+// -----
+
+// dcast cannot mix a ranked operand with an unranked result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_ranked_unranked_tensor(%arg0: tensor<!qalias>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<*xf32>
+  return
+}
+
+// -----
+
+// dcast cannot turn a static dimension (2) into a dynamic one (?).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3x!qalias>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<?x3xf32>
+  return
+}
+
+// -----
+
+// dcast result float type (f64) must match the expressed type (f32).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_float_type_mismatch(%arg0: !qalias) {
+  // expected-error@+1 {{expressed type in quantized type expected to match float type}}
+  %0 = quant.dcast %arg0 : !qalias to f64
+  return
+}
+
+// -----
+
+// Per-axis quantized types are rejected on scalar dcast operands.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_scalar(%arg0: !qalias) {
+  // expected-error@+1 {{scalar types may not use per-axis quantization}}
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return
+}
+
+// -----
+
+// The quantized axis (2) must be smaller than the tensor rank (2).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x!qalias>) {
+  // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+  %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32>
+  return
+}
+
+// -----
+
+// The size of the quantized axis (4) must equal the number of scales (3).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_num_scales_mismatch(%arg0: tensor<2x3x4x!qalias>) {
+  // expected-error@+1 {{quantized dimension size does not match number of scales}}
+  %0 = quant.dcast %arg0 : tensor<2x3x4x!qalias> to tensor<2x3x4xf32>
+  return
+}
+
+// -----
+
+// qcast requires a quantized result.
+func.func @qcast_invalid_result(%arg0: f32) {
+  // expected-error@+1 {{result #0 must be scalar or tensor of quantized type}}
+  %0 = quant.qcast %arg0 : f32 to f32
+  return
+}
+
+// -----
+
+// qcast requires a floating-point operand.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_invalid_input(%arg0: !qalias) {
+  // expected-error@+1 {{operand #0 must be scalar or tensor of floating-point}}
+  %0 = quant.qcast %arg0 : !qalias to !qalias
+  return
+}
+
+// -----
+
+// qcast operand and result must both be scalars or both be tensors.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_scalar_tensor(%arg0: tensor<f32>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.qcast %arg0 : tensor<f32> to !qalias
+  return
+}
+
+// -----
+
+// qcast cannot mix a ranked operand with an unranked result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_ranked_unranked_tensor(%arg0: tensor<f32>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.qcast %arg0 : tensor<f32> to tensor<*x!qalias>
+  return
+}
+
+// -----
+
+// qcast cannot turn a static dimension (2) into a dynamic one (?).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xf32>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<?x3x!qalias>
+  return
+}
+
+// -----
+
+// qcast operand float type (f64) must match the expressed type (f32).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_float_type_mismatch(%arg0: f64) {
+  // expected-error@+1 {{expressed type in quantized type expected to match float type}}
+  %0 = quant.qcast %arg0 : f64 to !qalias
+  return
+}
+
+// -----
+
+// Per-axis quantized types are rejected on scalar qcast results.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_scalar(%arg0: f32) {
+  // expected-error@+1 {{scalar types may not use per-axis quantization}}
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return
+}
+
+// -----
+
+// The quantized axis (2) must be smaller than the tensor rank (2).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3xf32>) {
+  // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+  %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias>
+  return
+}
+
+// -----
+
+// The size of the quantized axis (4) must equal the number of scales (3).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_num_scales_mismatch(%arg0: tensor<2x3x4xf32>) {
+  // expected-error@+1 {{quantized dimension size does not match number of scales}}
+  %0 = quant.qcast %arg0 : tensor<2x3x4xf32> to tensor<2x3x4x!qalias>
+  return
+}
+
+// -----
+
+// scast rejects a signed (non-signless) integer operand.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_input(%arg0: si32) {
+  // expected-error@+1 {{operand #0 must be scalar or tensor of signless integer or quantized type}}
+  %0 = quant.scast %arg0 : si32 to !qalias
+  return
+}
+
+// -----
+
+// scast rejects a signed (non-signless) integer result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_result(%arg0: !qalias) {
+  // expected-error@+1 {{result #0 must be scalar or tensor of signless integer or quantized type}}
+  %0 = quant.scast %arg0 : !qalias to si32
+  return
+}
+
+// -----
+
+// scast between two integer types is rejected.
+func.func @scast_both_integers(%arg0: i8) {
+  // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+  %0 = quant.scast %arg0 : i8 to i8
+  return
+}
+
+// -----
+
+// scast between two quantized types is rejected.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_both_quantized(%arg0: !qalias) {
+  // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+  %0 = quant.scast %arg0 : !qalias to !qalias
+  return
+}
+
+// -----
+
+// scast operand and result must both be scalars or both be tensors.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_scalar_tensor(%arg0: tensor<i8>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.scast %arg0 : tensor<i8> to !qalias
+  return
+}
+
+// -----
+
+// scast cannot mix a ranked operand with an unranked result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_ranked_unranked_tensor(%arg0: tensor<i8>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.scast %arg0 : tensor<i8> to tensor<*x!qalias>
+  return
+}
+
+// -----
+
+// scast cannot turn a static dimension (2) into a dynamic one (?).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xi8>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<?x3x!qalias>
+  return
+}
+
+// -----
+
+// scast integer type (i32) must match the storage type (i8). This case
+// exercises quant.scast, so it is named scast_* rather than qcast_*.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_integer_type_mismatch(%arg0: i32) {
+  // expected-error@+1 {{storage type in quantized type expected to match integer type}}
+  %0 = quant.scast %arg0 : i32 to !qalias
+  return
+}
+
+// -----
+
+// Per-axis quantized types are rejected on scalar scast results.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_scalar(%arg0: i8) {
+  // expected-error@+1 {{scalar types may not use per-axis quantization}}
+  %0 = quant.scast %arg0 : i8 to !qalias
+  return
+}
+
+// -----
+
+// The quantized axis (2) must be smaller than the tensor rank (2).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3xi8>) {
+  // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+  %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<2x3x!qalias>
+  return
+}
+
+// -----
+
+// The size of the quantized axis (4) must equal the number of scales (3).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_num_scales_mismatch(%arg0: tensor<2x3x4xi8>) {
+  // expected-error@+1 {{quantized dimension size does not match number of scales}}
+  %0 = quant.scast %arg0 : tensor<2x3x4xi8> to tensor<2x3x4x!qalias>
+  return
+}
+

diff  --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
new file mode 100644
index 00000000000000..6bba9f5c037727
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -0,0 +1,511 @@
+// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s
+
+// Per-layer dequantization of a scalar with signed (i8) storage: the stored
+// value and the zero point are converted with arith.sitofp, the zero point is
+// subtracted, and the difference is multiplied by the scale.
+// CHECK-LABEL: @dcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return %0 : f32
+}
+
+// -----
+
+// Per-layer dequantization of a scalar with unsigned (u8) storage: the stored
+// value and the zero point are converted with arith.uitofp instead of sitofp.
+// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<u8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return %0 : f32
+}
+
+// -----
+
+// Per-layer dequantization of a 0-D tensor: the scalar scale and zero point
+// are splatted to 0-D tensors and the arithmetic is done on tensors.
+// CHECK-LABEL: @dcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<i8>
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<i8> to tensor<f32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<f32>
+// CHECK: return %[[EXPRESSED]] : tensor<f32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_0d(%arg0: tensor<!qalias>) -> tensor<f32> {
+  %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Per-layer dequantization of a ranked tensor with one dynamic dimension: the
+// splats of scale and zero point take their dynamic size from tensor.dim on
+// dimension 1 of the stored tensor.
+// CHECK-LABEL: @dcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<3x?x5xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_INT]], %[[C_1]] : tensor<3x?x5xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+// CHECK: return %[[EXPRESSED]] : tensor<3x?x5xf32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_ranked(%arg0: tensor<3x?x5x!qalias>) -> tensor<3x?x5xf32> {
+  %0 = quant.dcast %arg0 : tensor<3x?x5x!qalias> to tensor<3x?x5xf32>
+  return %0 : tensor<3x?x5xf32>
+}
+
+// -----
+
+// Per-layer dequantization of an unranked tensor: the stored tensor is
+// flattened to 1-D (shape.shape_of + shape.num_elements + tensor.reshape),
+// dequantized, and reshaped back to the original shape.
+// CHECK-LABEL: @dcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<*xi8>
+// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[STORED_INT]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_INT]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<1xindex>) -> tensor<?xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_COLLAPSED]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<?xi8> to tensor<?xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[EXPRESSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+  %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Per-channel (axis 1) dequantization of a ranked tensor: lowered to a
+// linalg.generic whose scale and zero-point operands are indexed by d1 only,
+// so each element uses the parameters of its channel.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>> to tensor<4x?x?x5xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_1]] : tensor<4x?x?x5xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_2]] : tensor<4x?x?x5xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[STORED_TENSOR]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xi8>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK:   linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<4x?x?x5xf32>
+// CHECK: return %[[GENERIC]] : tensor<4x?x?x5xf32>
+
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @dcast_per_channel_ranked(%arg0: tensor<4x?x?x5x!qalias>) -> tensor<4x?x?x5xf32> {
+  %0 = quant.dcast %arg0 : tensor<4x?x?x5x!qalias> to tensor<4x?x?x5xf32>
+  return %0 : tensor<4x?x?x5xf32>
+}
+
+// -----
+
+// Per-channel (axis 2) dequantization of an unranked tensor: the shape is
+// split around the channel axis, the tensor is reshaped to 3-D as
+// [product of leading dims, number of channels, product of trailing dims],
+// dequantized with a linalg.generic indexed by the middle dimension, and
+// reshaped back to the original shape.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>> to tensor<*xi8>
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[STORED_TENSOR]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+
+// CHECK: %[[NUM_CHANNELS:.*]] = arith.constant 3 : index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[NUM_CHANNELS]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_TENSOR]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<3xindex>) -> tensor<?x3x?xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?x3x?xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_2]] : tensor<?x3x?xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[STORED_COLLAPSED]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xi8>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK:   linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<?x3x?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @dcast_per_channel_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+  %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Per-layer quantization of a scalar: divide by the scale, add the zero point
+// (converted with sitofp), and convert the sum to the i8 storage type.
+// NOTE(review): the expected lowering uses arith.fptosi, which rounds toward
+// zero rather than to nearest; confirm that is the intended rounding mode.
+// CHECK-LABEL: @qcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : f32 to i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : i8 to !quant.uniform<i8:f32, 2.000000e+00:10>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8:f32, 2.000000e+00:10>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_scalar(%arg0: f32) -> !qalias {
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return %0 : !qalias
+}
+
+// -----
+
+// Per-layer quantization into a narrowed signed range i8<-5:10> with an
+// implicit zero point of 0: fptosi consumes the divf result directly (no
+// zero-point add), then the value is clamped with maxsi/minsi to [-5, 10].
+// NOTE(review): the clamp runs after arith.fptosi, so floats outside the i8
+// range are converted before being clamped; confirm that is acceptable.
+// CHECK-LABEL: @qcast_per_layer_scalar_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_NEG_5:.*]] = arith.constant -5 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_5]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+
+!qalias = !quant.uniform<i8<-5:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_bounds(%arg0: f32) -> !qalias {
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return %0 : !qalias
+}
+
+// -----
+
+// Per-layer quantization into a narrowed unsigned range u8<2:10>: uses
+// arith.fptoui and clamps with the unsigned maxui/minui to [2, 10].
+// NOTE(review): as in the signed case, the clamp runs after the float-to-int
+// conversion; confirm the out-of-range conversion behavior is acceptable.
+// CHECK-LABEL: @qcast_per_layer_scalar_unsigned_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptoui %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxui %[[STORED_INT]], %[[C_2]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minui %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+
+!qalias = !quant.uniform<u8<2:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return %0 : !qalias
+}
+
+// -----
+
+// Per-layer quantization of a 0-D tensor: scale and zero point are splatted
+// to 0-D tensors, then divide, add the zero point, and fptosi elementwise.
+// CHECK-LABEL: @qcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
+  %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
+  return %0 : tensor<!qalias>
+}
+
+// -----
+
+// Per-layer quantization of a ranked tensor with one dynamic dimension: the
+// splats take their dynamic size from tensor.dim on dimension 1 of the input.
+// CHECK-LABEL: @qcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<3x?x5xf32> to tensor<3x?x5x!qalias>
+  return %0 : tensor<3x?x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]] : tensor<3x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_SPLAT]] : tensor<3x5xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<3x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<3x5xi8> to tensor<3x5xf32>
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<3x5xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<3x5xf32> to tensor<3x5xi8>
+
+// CHECK-DAG: %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK-DAG: %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK-DAG: %[[SPLAT_NEG_8:.*]] = tensor.splat %[[C_NEG_8]] : tensor<3x5xi8>
+// CHECK-DAG: %[[SPLAT_7:.*]] = tensor.splat %[[C_7]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[SPLAT_NEG_8]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[SPLAT_7]] : tensor<3x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : tensor<3x5xi8> to tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+// Per-layer qcast with narrowed storage bounds <-8:7>: after fptosi the value is clamped with arith.maxsi/arith.minsi against splatted bound constants.
+!qalias = !quant.uniform<i8<-8:7>:f32, 2.0:10>
+func.func @qcast_per_layer_ranked_bounds(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias>
+  return %0 : tensor<3x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[SIZE:.*]] = shape.num_elements %[[SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
+// CHECK: %[[RANKED_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[RANKED_INPUT]], %[[C_0]] : tensor<?xf32>
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[RANKED_INPUT]], %[[SCALE_SPLAT]] : tensor<?xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<?xf32> to tensor<?xi8>
+
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[STORED_INT]](%[[SHAPE]]) : (tensor<?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// Per-layer qcast of an unranked input: it is flattened to 1-D (shape.num_elements + tensor.reshape), quantized as in the ranked case, then reshaped back to the original shape.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+  return %0 : tensor<*x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x?x?x5xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<4x?x?x5xf32>
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<4x?x?x5xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<4x?x?x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x?x?x5xi8> to tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+// Per-channel qcast (axis 1) of a ranked input: scales and zero points become constant 1-D tensors indexed by d1 in a linalg.generic whose i8 init uses the dynamic dims.
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<4x?x?x5xf32> to tensor<4x?x?x5x!qalias>
+  return %0 : tensor<4x?x?x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x2x5xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<0> : tensor<2xi8>
+
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x5xi8>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x2x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK:   %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK:   %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_8]] : i8
+// CHECK:   %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_7]] : i8
+// CHECK:   linalg.yield %[[STORED_CLAMPED]] : i8
+// CHECK: } -> tensor<4x2x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x2x5xi8> to tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+// Per-channel qcast with storage bounds <-8:7> and no explicit zero points: zero points lower to dense<0> and each element is clamped inside the linalg.generic body.
+!qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
+func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias>
+  return %0 : tensor<4x2x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+
+// CHECK: %[[CHANNEL_AXIS_SIZE:.*]] = arith.constant 3 : index
+// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[CHANNEL_AXIS_SIZE]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[FLAT_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_0]] : tensor<?x3x?xf32>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_2]] : tensor<?x3x?xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[FLAT_INPUT]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xf32>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<?x3x?xi8>
+
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+// Per-channel qcast (axis 2) of an unranked input: the shape is split around the channel axis, the input is reshaped to ?x3x? for the linalg.generic, and the result is reshaped back to the original shape.
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+  return %0 : tensor<*x!qalias>
+}
+

diff  --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
new file mode 100644
index 00000000000000..4abc5830d081e1
--- /dev/null
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -0,0 +1,151 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+// Scalar quant.dcast from a per-layer quantized value to its f32 expressed type.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_scalar(%quantized: !qalias) {
+  %expressed = quant.dcast %quantized : !qalias to f32
+  return
+}
+
+// -----
+// quant.dcast on a ranked tensor with a dynamic dimension.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_ranked(%quantized: tensor<2x?x4x!qalias>) {
+  %expressed = quant.dcast %quantized : tensor<2x?x4x!qalias> to tensor<2x?x4xf32>
+  return
+}
+
+// -----
+// quant.dcast on an unranked tensor.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_unranked(%quantized: tensor<*x!qalias>) {
+  %expressed = quant.dcast %quantized : tensor<*x!qalias> to tensor<*xf32>
+  return
+}
+
+// -----
+// Per-axis quant.dcast (axis 2, three scales) on a static shape.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_static(%quantized: tensor<1x2x3x!qalias>) {
+  %expressed = quant.dcast %quantized : tensor<1x2x3x!qalias> to tensor<1x2x3xf32>
+  return
+}
+
+// -----
+// Per-axis quant.dcast on a fully dynamic ranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_dynamic(%quantized: tensor<?x?x?x!qalias>) {
+  %expressed = quant.dcast %quantized : tensor<?x?x?x!qalias> to tensor<?x?x?xf32>
+  return
+}
+
+// -----
+// Per-axis quant.dcast on an unranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_unranked(%quantized: tensor<*x!qalias>) {
+  %expressed = quant.dcast %quantized : tensor<*x!qalias> to tensor<*xf32>
+  return
+}
+
+// -----
+// Scalar quant.qcast from f32 to a per-layer quantized value.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_scalar(%expressed: f32) {
+  %quantized = quant.qcast %expressed : f32 to !qalias
+  return
+}
+
+// -----
+// quant.qcast on a ranked tensor with a dynamic dimension.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_ranked(%expressed: tensor<2x?x4xf32>) {
+  %quantized = quant.qcast %expressed : tensor<2x?x4xf32> to tensor<2x?x4x!qalias>
+  return
+}
+
+// -----
+// quant.qcast on an unranked tensor.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_unranked(%expressed: tensor<*xf32>) {
+  %quantized = quant.qcast %expressed : tensor<*xf32> to tensor<*x!qalias>
+  return
+}
+
+// -----
+// Per-axis quant.qcast (axis 2, three scales) on a static shape.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_static(%expressed: tensor<1x2x3xf32>) {
+  %quantized = quant.qcast %expressed : tensor<1x2x3xf32> to tensor<1x2x3x!qalias>
+  return
+}
+
+// -----
+// Per-axis quant.qcast on a fully dynamic ranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_dynamic(%expressed: tensor<?x?x?xf32>) {
+  %quantized = quant.qcast %expressed : tensor<?x?x?xf32> to tensor<?x?x?x!qalias>
+  return
+}
+
+// -----
+// Per-axis quant.qcast on an unranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_unranked(%expressed: tensor<*xf32>) {
+  %quantized = quant.qcast %expressed : tensor<*xf32> to tensor<*x!qalias>
+  return
+}
+
+// -----
+// Scalar quant.scast round trip between i8 storage and the quantized type.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_scalar(%storage: i8) {
+  %quantized = quant.scast %storage : i8 to !qalias
+  %restored = quant.scast %quantized : !qalias to i8
+  return
+}
+
+// -----
+// quant.scast round trip on a ranked tensor with a dynamic dimension.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_ranked(%storage: tensor<2x?x4xi8>) {
+  %quantized = quant.scast %storage : tensor<2x?x4xi8> to tensor<2x?x4x!qalias>
+  %restored = quant.scast %quantized : tensor<2x?x4x!qalias> to tensor<2x?x4xi8>
+  return
+}
+
+// -----
+// quant.scast round trip on an unranked tensor.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_unranked(%storage: tensor<*xi8>) {
+  %quantized = quant.scast %storage : tensor<*xi8> to tensor<*x!qalias>
+  %restored = quant.scast %quantized : tensor<*x!qalias> to tensor<*xi8>
+  return
+}
+
+// -----
+// Per-axis quant.scast round trip (axis 2, three scales) on a static shape.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_static(%storage: tensor<1x2x3xi8>) {
+  %quantized = quant.scast %storage : tensor<1x2x3xi8> to tensor<1x2x3x!qalias>
+  %restored = quant.scast %quantized : tensor<1x2x3x!qalias> to tensor<1x2x3xi8>
+  return
+}
+
+// -----
+// Per-axis quant.scast round trip on a fully dynamic ranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_dynamic(%storage: tensor<?x?x?xi8>) {
+  %quantized = quant.scast %storage : tensor<?x?x?xi8> to tensor<?x?x?x!qalias>
+  %restored = quant.scast %quantized : tensor<?x?x?x!qalias> to tensor<?x?x?xi8>
+  return
+}
+
+// -----
+// Per-axis quant.scast round trip on an unranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_unranked(%storage: tensor<*xi8>) {
+  %quantized = quant.scast %storage : tensor<*xi8> to tensor<*x!qalias>
+  %restored = quant.scast %quantized : tensor<*x!qalias> to tensor<*xi8>
+  return
+}
+
+

diff  --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index a82e8efdb1a3c3..7613a344cf2b8f 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -120,3 +120,28 @@
 // provided.
 // expected-error at +1 {{expected floating point literal}}
 !qalias = !quant.uniform<i8<-4:3>:f32, {2.000000e+02,-19.987200e-01:1}>
+
+// -----
+// Per-axis quantization with a negative quantized dimension is rejected.
+// expected-error@+1 {{illegal quantized dimension: -1}}
+!qalias = !quant.uniform<i8:f32:-1, {2.0,3.0:1}>
+
+// -----
+// Per-layer scale too small for an f16 expressed type (underflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16, 5.8e-8>
+
+// -----
+// Per-layer scale too large for an f16 expressed type (overflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16, 6.6e4>
+
+// -----
+// Per-axis scale too small for an f16 expressed type (underflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:1, {2.0,5.8e-8}>
+
+// -----
+// Per-axis scale too large for an f16 expressed type (overflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:1, {2.0,6.6e4}>

diff  --git a/mlir/test/Dialect/Quant/strip-func-quant-types.mlir b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
new file mode 100644
index 00000000000000..e5f0d4921bed3e
--- /dev/null
+++ b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @strip_operands
+// CHECK-SAME: %[[ARG_0:.*]]: i8
+// CHECK-SAME: %[[ARG_1:.*]]: i16
+// CHECK-SAME: %[[ARG_2:.*]]: f32
+
+// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}>
+// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}>
+
+// CHECK: "test.custom_op"(%[[ARG_0_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_1_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_2]])
+// Quantized arguments become their storage types (i8, i16) and are scast back to the quantized type for their users; the f32 argument is used unchanged.
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) {
+  "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias>
+  "test.custom_op"(%arg1) : (!qalias1) -> tensor<?x!qalias1>
+  "test.custom_op"(%arg2) : (f32) -> tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @strip_results
+// CHECK-SAME: tensor<4xi8>, tensor<?xi16>, tensor<*xi8>, tensor<4xf32>
+
+// CHECK: %[[RESULT_0:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[RESULT_1:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULT_2:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8>
+
+// CHECK: %[[RESULT_3:.*]] = "test.custom_op"()
+
+// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]]
+// Quantized result types are replaced by their storage types in the signature; each quantized value is scast to storage before the return, and the f32 result is returned unchanged.
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_results() -> (tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>) {
+  %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+  %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+  %2 = "test.custom_op"() : () -> tensor<*x!qalias>
+  %3 = "test.custom_op"() : () -> tensor<4xf32>
+  return %0, %1, %2, %3 : tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @callee
+// CHECK-SAME: (tensor<4xi8>, tensor<?xi16>) -> (tensor<*xi8>, tensor<4xf32>)
+
+// CHECK-LABEL: @strip_call
+
+// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]])
+
+// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>>
+// CHECK: "test.custom_op"(%[[RESULT_0_CAST]])
+
+// CHECK: "test.custom_op"(%[[RESULTS]]#1)
+
+// CHECK: return
+// The private callee's quantized types are stripped to storage types; call operands are scast to storage before the call, the quantized result is scast back afterwards, and the f32 result is used directly.
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func private @callee(tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+
+func.func @strip_call() {
+  %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+  %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+  %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+  "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> ()
+  "test.custom_op"(%2#1) : (tensor<4xf32>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list