[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)

Rafael Ubal llvmlistbot at llvm.org
Thu Jul 25 15:47:19 PDT 2024


https://github.com/rafaelubalmw created https://github.com/llvm/llvm-project/pull/100667

Full revamp of the 'quant' dialect. This is an implementation for the RFC at https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942

>From 4de6f81e366e04b51b1ad1a22e911b5412302ec6 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 8 Jul 2024 17:58:38 -0400
Subject: [PATCH 01/18] File system restructure for 'quant' dialect

---
 .../include/mlir/Dialect/Quant/CMakeLists.txt |  8 ++----
 .../mlir/Dialect/Quant/IR/CMakeLists.txt      |  6 +++++
 .../Dialect/Quant/{QuantOps.h => IR/Quant.h}  | 12 ++++-----
 .../{QuantOpsBase.td => IR/QuantBase.td}      |  8 +++---
 .../Quant/{ => IR}/QuantDialectBytecode.td    |  0
 .../mlir/Dialect/Quant/{ => IR}/QuantOps.td   |  8 +++---
 .../mlir/Dialect/Quant/{ => IR}/QuantTypes.h  |  6 ++---
 .../Dialect/Quant/Transforms/CMakeLists.txt   |  5 ++++
 .../mlir/Dialect/Quant/Transforms/Passes.h    | 27 +++++++++++++++++++
 .../mlir/Dialect/Quant/Transforms/Passes.td   | 26 ++++++++++++++++++
 .../Quant/{ => Utils}/FakeQuantSupport.h      |  8 +++---
 .../Quant/{ => Utils}/UniformSupport.h        |  8 +++---
 .../mlir/Dialect/Tosa/Utils/QuantUtils.h      |  4 +--
 mlir/include/mlir/InitAllDialects.h           |  2 +-
 mlir/lib/CAPI/Dialect/Quant.cpp               |  4 +--
 .../Dialect/Quant/IR/QuantDialectBytecode.cpp |  6 ++---
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        | 10 +++----
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      |  4 +--
 mlir/lib/Dialect/Quant/IR/TypeParser.cpp      |  4 +--
 .../Dialect/Quant/Utils/FakeQuantSupport.cpp  |  4 +--
 .../Dialect/Quant/Utils/UniformSupport.cpp    |  2 +-
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  2 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |  2 +-
 23 files changed, 113 insertions(+), 53 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
 rename mlir/include/mlir/Dialect/Quant/{QuantOps.h => IR/Quant.h} (69%)
 rename mlir/include/mlir/Dialect/Quant/{QuantOpsBase.td => IR/QuantBase.td} (93%)
 rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantDialectBytecode.td (100%)
 rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantOps.td (96%)
 rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantTypes.h (99%)
 create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
 rename mlir/include/mlir/Dialect/Quant/{ => Utils}/FakeQuantSupport.h (93%)
 rename mlir/include/mlir/Dialect/Quant/{ => Utils}/UniformSupport.h (97%)

diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
index c08f399ee182d..9f57627c321fb 100644
--- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
@@ -1,6 +1,2 @@
-add_mlir_dialect(QuantOps quant)
-add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
-
-set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
-mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
-add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..c08f399ee182d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(QuantOps quant)
+add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
+mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
+add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
similarity index 69%
rename from mlir/include/mlir/Dialect/Quant/QuantOps.h
rename to mlir/include/mlir/Dialect/Quant/IR/Quant.h
index 14fb3035ab0d3..a703612d6b489 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -1,4 +1,4 @@
-//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
+//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_
-#define MLIR_DIALECT_QUANT_QUANTOPS_H_
+#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_
+#define MLIR_DIALECT_QUANT_IR_QUANT_H_
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -19,9 +19,9 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
 
-#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc"
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.h.inc"
 
-#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_
+#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
similarity index 93%
rename from mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index da822d0a61deb..dadca06091b1e 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -1,4 +1,4 @@
-//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
+//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,8 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef DIALECT_QUANT_QUANT_OPS_BASE_
-#define DIALECT_QUANT_QUANT_OPS_BASE_
+#ifndef QUANT_BASE
+#define QUANT_BASE
 
 include "mlir/IR/OpBase.td"
 
@@ -71,4 +71,4 @@ def quant_UniformQuantizedType :
 def quant_UniformQuantizedValueType :
     quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
 
-#endif // DIALECT_QUANT_QUANT_OPS_BASE_
+#endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
similarity index 100%
rename from mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
similarity index 96%
rename from mlir/include/mlir/Dialect/Quant/QuantOps.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 7937265ce2f20..a3a0ff1608a66 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -10,10 +10,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef DIALECT_QUANT_QUANT_OPS_
-#define DIALECT_QUANT_QUANT_OPS_
+#ifndef QUANT_OPS
+#define QUANT_OPS
 
-include "mlir/Dialect/Quant/QuantOpsBase.td"
+include "mlir/Dialect/Quant/IR/QuantBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -100,4 +100,4 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
   let hasFolder = 1;
 }
 
-#endif // DIALECT_QUANT_QUANT_OPS_
+#endif // QUANT_OPS
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
similarity index 99%
rename from mlir/include/mlir/Dialect/Quant/QuantTypes.h
rename to mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index de5aed0a91a20..c020e1b46ad4e 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
-#define MLIR_DIALECT_QUANT_QUANTTYPES_H
+#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
+#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -412,4 +412,4 @@ class CalibratedQuantizedType
 } // namespace quant
 } // namespace mlir
 
-#endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
+#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..30f7c1696bdb9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant)
+add_public_tablegen_target(MLIRQuantTransformsIncGen)
+
+add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
new file mode 100644
index 0000000000000..0b7378651afa1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
@@ -0,0 +1,27 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+/// Generate the pass registration entry points, e.g. registerQuantPasses().
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
new file mode 100644
index 0000000000000..f511c90ec6931
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -0,0 +1,26 @@
+//===-- Passes.td - Quant pass definition file -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> {
+  let summary = "Lower 'quant.dcast' and 'quant.qcast' ops.";
+  let description = [{
+    Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
+    into other core dialects.
+
+    The lowering process generates storage type casts in the form of
+    `quant.scast` ops to convert operands and results from quantized types to
+    the corresponding storage type, or vice versa.
+  let dependentDialects = ["quant::QuantDialect"];
+}
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
similarity index 93%
rename from mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
index 367d468b2acf1..6551efc6242a6 100644
--- a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
@@ -34,10 +34,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
-#define MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 namespace mlir {
 namespace quant {
@@ -64,4 +64,4 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
 } // namespace quant
 } // namespace mlir
 
-#endif // MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
similarity index 97%
rename from mlir/include/mlir/Dialect/Quant/UniformSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
index 4119aced4c075..6773f45069c87 100644
--- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
@@ -6,12 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
-#define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
 
 #include <utility>
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
@@ -218,4 +218,4 @@ class UniformQuantizedPerAxisValueConverter {
 } // namespace quant
 } // namespace mlir
 
-#endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 298c97015fe2e..5e80745777b3b 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -16,8 +16,8 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
 
 namespace mlir {
 namespace tosa {
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 549c26c72d8a1..75e62cda90d45 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -64,7 +64,7 @@
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 0a7181d8bc17c..b30d1de73288c 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -8,8 +8,8 @@
 
 #include "mlir-c/Dialect/Quant.h"
 #include "mlir/CAPI/Registration.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 using namespace mlir;
 
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index c0c00fb4893cb..0f4b755367495 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -9,8 +9,8 @@
 
 #include "QuantDialectBytecode.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/SmallVector.h"
@@ -31,7 +31,7 @@ static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader,
   return success();
 }
 
-#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc"
 
 /// This class implements the bytecode interface for the Quant dialect.
 struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index c9a6bbc9ceeea..fa9725c23d643 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -6,11 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantOps.h"
 #include "QuantDialectBytecode.h"
 #include "TypeDetail.h"
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
@@ -24,14 +24,14 @@ using namespace mlir;
 using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
-#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
 void QuantizationDialect::initialize() {
   addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>();
   addOperations<
 #define GET_OP_LIST
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
       >();
   addBytecodeInterface(this);
 }
@@ -46,4 +46,4 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
 }
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be..15cde77e40afb 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
 #include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 926a8a0aa13d5..c882a616f397c 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index 8c69729824691..fb27640bfd278 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
 
 using namespace mlir;
 using namespace mlir::quant;
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index 408701f80444a..62c7a7128d63a 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include <numeric>
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8687be075ea67..5d082ea9b1010 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -11,7 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4337787e4aead..d3aff36a763b6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -14,7 +14,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"

>From d6e9adcc3d8859f9dbedbf9360d59cc91a99b9ed Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 8 Jul 2024 19:26:06 -0400
Subject: [PATCH 02/18] Successfully compiled empty pass '-lower-quant-ops'

---
 .../mlir/Dialect/Quant/IR/QuantBase.td        |  4 +--
 .../include/mlir/Dialect/Quant/IR/QuantOps.td |  2 +-
 .../mlir/Dialect/Quant/Transforms/Passes.td   |  5 +--
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |  2 +-
 mlir/include/mlir/InitAllDialects.h           |  2 +-
 mlir/include/mlir/InitAllPasses.h             |  2 ++
 mlir/lib/CAPI/Dialect/Quant.cpp               |  2 +-
 mlir/lib/Dialect/Quant/CMakeLists.txt         |  1 +
 .../Dialect/Quant/IR/QuantDialectBytecode.cpp |  2 +-
 .../Dialect/Quant/IR/QuantDialectBytecode.h   |  4 +--
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        |  2 +-
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      |  2 +-
 mlir/lib/Dialect/Quant/IR/TypeParser.cpp      |  4 +--
 .../Dialect/Quant/Transforms/CMakeLists.txt   | 15 ++++++++
 .../Quant/Transforms/LowerQuantOps.cpp        | 34 +++++++++++++++++++
 15 files changed, 68 insertions(+), 15 deletions(-)
 create mode 100644 mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp

diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index dadca06091b1e..e465d855c1986 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -15,7 +15,7 @@
 
 include "mlir/IR/OpBase.td"
 
-def Quantization_Dialect : Dialect {
+def Quant_Dialect : Dialect {
   let name = "quant";
   let cppNamespace = "::mlir::quant";
 
@@ -63,7 +63,7 @@ def quant_RealOrStorageValueType :
 
 // An implementation of UniformQuantizedType.
 def quant_UniformQuantizedType :
-    DialectType<Quantization_Dialect,
+    DialectType<Quant_Dialect,
                 CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
                 "UniformQuantizedType">;
 
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index a3a0ff1608a66..ba282d50328da 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 //===----------------------------------------------------------------------===//
 
 class quant_Op<string mnemonic, list<Trait> traits> :
-    Op<Quantization_Dialect, mnemonic, traits>;
+    Op<Quant_Dialect, mnemonic, traits>;
 
 //===----------------------------------------------------------------------===//
 // Quantization casts
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index f511c90ec6931..e43062a98b1ea 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -11,7 +11,7 @@
 
 include "mlir/Pass/PassBase.td"
 
-def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> {
+def LowerQuantOps : Pass<"lower-quant-ops"> {
   let summary = "Lower 'quant.dcast' and 'quant.qcast' ops.";
   let description = [{
     Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
@@ -20,7 +20,8 @@ def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> {
     The lowering process generates storage type casts in the form of
     `quant.scast` ops to convert operands and results from quantized types to
     the corresponding storage type, or vice versa.
-  let dependentDialects = ["quant::QuantDialect"];
+  }];
+  let dependentDialects = ["::mlir::quant::QuantDialect"];
 }
 
 #endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d2..df91ba51a0594 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -40,7 +40,7 @@ def Tosa_Dialect : Dialect {
     there will be tools to lower from the ML frameworks into TOSA.
   }];
 
-  let dependentDialects = ["tensor::TensorDialect", "quant::QuantizationDialect"];
+  let dependentDialects = ["tensor::TensorDialect", "quant::QuantDialect"];
 
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 75e62cda90d45..08de36fe21db0 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -136,7 +136,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   pdl_interp::PDLInterpDialect,
                   polynomial::PolynomialDialect,
                   ptr::PtrDialect,
-                  quant::QuantizationDialect,
+                  quant::QuantDialect,
                   ROCDL::ROCDLDialect,
                   scf::SCFDialect,
                   shape::ShapeDialect,
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 1b9c1b193ace6..dd8b292a87344 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -35,6 +35,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
   memref::registerMemRefPasses();
   mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
+  quant::registerQuantPasses();
   registerSCFPasses();
   registerShapePasses();
   spirv::registerSPIRVPasses();
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index b30d1de73288c..c94dbb5692fdb 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -13,7 +13,7 @@
 
 using namespace mlir;
 
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
 
 //===---------------------------------------------------------------------===//
 // QuantizedType
diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt
index 037bba8dcb5c9..31167e6af908b 100644
--- a/mlir/lib/Dialect/Quant/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index 0f4b755367495..6a4ac310eb052 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -64,6 +64,6 @@ struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
 };
 } // namespace
 
-void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
+void quant::detail::addBytecodeInterface(QuantDialect *dialect) {
   dialect->addInterfaces<QuantDialectBytecodeInterface>();
 }
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
index 9e9cbf66d84d9..eef2b5bbefecc 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
@@ -15,12 +15,12 @@
 #define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
 
 namespace mlir::quant {
-class QuantizationDialect;
+class QuantDialect;
 
 namespace detail {
 /// Add the interfaces necessary for encoding the quantization dialect
 /// components in bytecode.
-void addBytecodeInterface(QuantizationDialect *dialect);
+void addBytecodeInterface(QuantDialect *dialect);
 } // namespace detail
 } // namespace mlir::quant
 
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index fa9725c23d643..49c05aa7f98d3 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -26,7 +26,7 @@ using namespace mlir::quant::detail;
 
 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
-void QuantizationDialect::initialize() {
+void QuantDialect::initialize() {
   addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>();
   addOperations<
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 15cde77e40afb..a4829d472ecad 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -25,7 +25,7 @@ unsigned QuantizedType::getFlags() const {
 }
 
 bool QuantizedType::classof(Type type) {
-  return llvm::isa<QuantizationDialect>(type.getDialect());
+  return llvm::isa<QuantDialect>(type.getDialect());
 }
 
 LogicalResult
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index c882a616f397c..bf0f775146c1a 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -317,7 +317,7 @@ static Type parseCalibratedType(DialectAsmParser &parser) {
 }
 
 /// Parse a type registered to this dialect.
-Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
+Type QuantDialect::parseType(DialectAsmParser &parser) const {
   // All types start with an identifier that we switch on.
   StringRef typeNameSpelling;
   if (failed(parser.parseKeyword(&typeNameSpelling)))
@@ -419,7 +419,7 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type,
 }
 
 /// Print a type registered to this dialect.
-void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
+void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
   if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
     printAnyQuantizedType(anyType, os);
   else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..2daea7750cfe3
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRQuantTransforms
+  LowerQuantOps.cpp
+
+  # Public pass headers (Passes.h / Passes.td) for the 'quant' transforms.
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
+
+  DEPENDS
+  MLIRQuantTransformsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRQuantDialect
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
new file mode 100644
index 0000000000000..72f89326f555b
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -0,0 +1,34 @@
+//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_LOWERQUANTOPS
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+// Pass skeleton for 'lower-quant-ops'. No lowering patterns are registered
+// yet, so the pass intentionally leaves the IR unchanged.
+struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
+  void runOnOperation() override {
+    // Intentionally empty. Do not bind getOperation() to an unused local: that
+    // triggers -Wunused-variable, which is fatal under -Werror builds.
+  }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir

>From 7e0a7b9edb810dc80eba435cbdabf2dd1e53d9b0 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 10 Jul 2024 19:51:20 -0400
Subject: [PATCH 03/18] Lowering for 'quant.qcast' with per-layer quantization
 for ranked and unranked tensors

---
 mlir/include/mlir/Dialect/Quant/IR/Quant.h    |   8 +
 .../include/mlir/Dialect/Quant/IR/QuantOps.td |  46 ++--
 .../mlir/Dialect/Quant/IR/QuantTypes.h        |   4 +
 .../mlir/Dialect/Quant/Transforms/Passes.h    |   2 +
 .../mlir/Dialect/Quant/Transforms/Passes.td   |  17 +-
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        |  53 +++-
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      |  11 +
 mlir/lib/Dialect/Quant/IR/TypeParser.cpp      |   7 +-
 .../Quant/Transforms/LowerQuantOps.cpp        | 253 +++++++++++++++++-
 9 files changed, 363 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Quant/IR/Quant.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
index a703612d6b489..c5ca88ec69795 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/Quant.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -21,6 +21,14 @@
 
 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc"
 
+namespace mlir {
+namespace quant {
+
+// Forward declaration required by the generated op classes included below:
+// QuantizeCastOp::getQuantizedType() returns a quant::UniformQuantizedType.
+class UniformQuantizedType;
+
+} // namespace quant
+} // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Quant/IR/QuantOps.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index ba282d50328da..48e2496203ff0 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -28,6 +28,24 @@ class quant_Op<string mnemonic, list<Trait> traits> :
 // Quantization casts
 //===----------------------------------------------------------------------===//
 
+// 'quant.dcast': dequantization, the inverse of 'quant.qcast'. Both the
+// operand and the result are constrained to quant_RealValueType.
+def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
+  let summary = "convert back from a quantized to quantizable (expressed) type operation";
+  let description = [{
+    A DequantizeCast op `dcast` represents the inverse of a `qcast`,
+    converting back from a quantized to quantizable (expressed) type.
+
+    Like `qcast`s, a `dcast` is allowed to have both its operand and result
+    as non quantized types. This facilitates transformations and marks edges
+    where the computation must be carried out in the expressed type.
+
+    Especially early in transformation, it is common to have `dcast`s on
+    all operands to ops that must operate with the expressed type (typically
+    math ops prior to lowering to target-specific, quantized kernels).
+  }];
+  let arguments = (ins quant_RealValueType:$input);
+  let results = (outs quant_RealValueType:$result);
+}
+
 def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
   let summary = "convert a quantizable type to a quantized type";
   let description = [{
@@ -52,26 +70,18 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
     it is legal to use a quantized representation (but is not known to be
     acceptable).
   }];
-  let arguments = (ins quant_RealValueType:$arg);
-  let results = (outs quant_RealValueType:$res);
-}
+  let arguments = (ins quant_RealValueType:$input);
+  let results = (outs quant_RealValueType:$result);
 
-def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
-  let summary = "convert back from a quantized to quantizable (expressed) type operation";
-  let description = [{
-    A DequantizeCast op `dcast` represents the inverse of a `qcast`,
-    converting back from a quantized to quantizable (expressed) type.
+  let extraClassDeclaration = [{
 
-    Like `qcast`s, a `dcast` is allowed to have both its operand and result
-    as non quantized types. This facilitates transformations and marks edges
-    where the computation must be carried out in the expressed type.
+    /// Return the primitive (scalar or tensor element) type of the float input.
+    FloatType getFloatType();
 
-    Especially early in transformation, it is common to have `dcast`s on
-    all operands to ops that must operate with the expressed type (typically
-    math ops prior to lowering to target-specific, quantized kernels).
+    /// Return the primitive (scalar or tensor element) type of the quantized
+    /// result.
+    quant::UniformQuantizedType getQuantizedType();
   }];
-  let arguments = (ins quant_RealValueType:$arg);
-  let results = (outs quant_RealValueType:$res);
 }
 
 def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
@@ -95,8 +105,8 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
     vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
     ```
   }];
-  let arguments = (ins quant_RealOrStorageValueType:$arg);
-  let results = (outs quant_RealOrStorageValueType:$res);
+  let arguments = (ins quant_RealOrStorageValueType:$input);
+  let results = (outs quant_RealOrStorageValueType:$result);
   let hasFolder = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index c020e1b46ad4e..d4c9e7f9286a6 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -114,6 +114,10 @@ class QuantizedType : public Type {
   /// The maximum value that storageType can take.
   int64_t getStorageTypeMax() const;
 
+  /// Return whether the storage type has explicit min or max boundaries
+  /// different from the minimum and maximum representable values, i.e.
+  /// whether getStorageTypeMin()/getStorageTypeMax() differ from the default
+  /// range of an integer with the storage type's integral width and
+  /// signedness.
+  bool hasStorageTypeBounds() const;
+
   /// Gets the integral bit width that the underlying storage type can exactly
   /// represent. For integral storage types, this will just be their width.
   unsigned getStorageTypeIntegralWidth() const;
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
index 0b7378651afa1..84be2a21b34ed 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
@@ -21,6 +21,8 @@ namespace quant {
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
 
+/// Add the conversion patterns for 'quant.qcast' and 'quant.dcast' that are
+/// used by the 'lower-quant-ops' pass to 'patterns'.
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns);
+
 } // namespace quant
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index e43062a98b1ea..19aa37f653c01 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -11,17 +11,24 @@
 
 include "mlir/Pass/PassBase.td"
 
-def LowerQuantOps : Pass<"lower-quant-ops"> {
-  let summary = "Lower 'quant.dcast' and 'quant.qcast' ops.";
+// Anchored on func::FuncOp; the patterns are implemented in
+// lib/Dialect/Quant/Transforms/LowerQuantOps.cpp.
+def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
+  let summary = "Lower quant.dcast and quant.qcast ops";
   let description = [{
     Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
     into other core dialects.
 
     The lowering process generates storage type casts in the form of
-    `quant.scast` ops to convert operands and results from quantized types to
-    the corresponding storage type, or vice versa.
+    `quant.scast` ops to act as an interface between the original quantized
+    types of operands and results and their corresponding storage types used in
+    the generated arithmetic computations.
   }];
-  let dependentDialects = ["::mlir::quant::QuantDialect"];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "linalg::LinalgDialect",
+    "quant::QuantDialect",
+    "scf::SCFDialect",
+    "tensor::TensorDialect"
+  ];
 }
 
 #endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 49c05aa7f98d3..2eafa348f6906 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -20,12 +20,27 @@
 #include "llvm/Support/MathExtras.h"
 #include <numeric>
 
-using namespace mlir;
-using namespace mlir::quant;
-using namespace mlir::quant::detail;
-
 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
+
+namespace mlir {
+namespace quant {
+
+namespace {
+
+// Return the element type of 'ty' when it is a tensor, otherwise 'ty' itself.
+// NOTE(review): mirrors getScalarType() in Quant/Transforms/LowerQuantOps.cpp;
+// consider sharing a single helper.
+Type getPrimitiveType(Type ty) {
+  if (auto tensorType = dyn_cast<TensorType>(ty))
+    return tensorType.getElementType();
+  return ty;
+}
+
+} // namespace
+
+
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
+
 void QuantDialect::initialize() {
   addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>();
@@ -33,17 +48,39 @@ void QuantDialect::initialize() {
 #define GET_OP_LIST
 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
       >();
-  addBytecodeInterface(this);
+  detail::addBytecodeInterface(this);
 }
 
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+/// Return the float element type of the input (the input type itself when it
+/// is a scalar). cast<> asserts if that type is not a FloatType.
+FloatType QuantizeCastOp::getFloatType() {
+  return cast<FloatType>(getPrimitiveType(getInput().getType()));
+}
+
+/// Return the UniformQuantizedType element type of the result (the result
+/// type itself when it is a scalar).
+/// NOTE(review): cast<> asserts for any other element type, e.g. per-axis
+/// quantized results; confirm callers only use this on per-layer casts.
+UniformQuantizedType QuantizeCastOp::getQuantizedType() {
+  return cast<UniformQuantizedType>(getPrimitiveType(getResult().getType()));
+}
+
+
+//===----------------------------------------------------------------------===//
+// StorageCastOp
+//===----------------------------------------------------------------------===//
+
 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
   // Matches x -> [scast -> scast] -> y, replacing the second scast with the
   // value of x if the casts invert each other.
-  auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
-  if (!srcScastOp || srcScastOp.getArg().getType() != getType())
+  // The pair inverts when the first scast's input type equals this op's
+  // result type; otherwise no fold is performed.
+  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
     return OpFoldResult();
-  return srcScastOp.getArg();
+  return srcScastOp.getInput();
 }
 
+} // namespace quant
+} // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
+
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index a4829d472ecad..2038a86bec8d6 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -72,6 +72,17 @@ int64_t QuantizedType::getStorageTypeMax() const {
   return static_cast<ImplType *>(impl)->storageTypeMax;
 }
 
+bool QuantizedType::hasStorageTypeBounds() const {
+  // Explicit bounds exist when either limit differs from the full range of an
+  // integer with this storage type's integral width and signedness.
+  const unsigned width = getStorageTypeIntegralWidth();
+  const bool signedStorage = isSigned();
+  if (getStorageTypeMin() != getDefaultMinimumForInteger(signedStorage, width))
+    return true;
+  return getStorageTypeMax() !=
+         getDefaultMaximumForInteger(signedStorage, width);
+}
+
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
   // NOTE: If ever supporting non-integral storage types, some other scheme
   // for determining the width will be needed.
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index bf0f775146c1a..851763d8942e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -346,12 +346,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   }
 
   // storageTypeMin and storageTypeMax if not default.
-  int64_t defaultIntegerMin =
-      QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
-  int64_t defaultIntegerMax =
-      QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
-  if (defaultIntegerMin != type.getStorageTypeMin() ||
-      defaultIntegerMax != type.getStorageTypeMax()) {
+  if (type.hasStorageTypeBounds()) {
     out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
         << ">";
   }
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 72f89326f555b..1067dc4f950d3 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -10,9 +10,16 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 namespace quant {
@@ -22,13 +29,257 @@ namespace quant {
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+// Placeholder for the 'quant.dcast' lowering, which is not implemented yet.
+// It reports a match failure instead of returning success(): a conversion
+// pattern that returns success() must replace or update the root op, and
+// returning success() without doing so trips the conversion driver's checks.
+class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+public:
+  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return rewriter.notifyMatchFailure(
+        op, "lowering of quant.dcast is not implemented yet");
+  }
+};
+
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+// Element type of 'containerType' when it is a tensor; otherwise
+// 'containerType' is itself the scalar type and is returned unchanged.
+Type getScalarType(Type containerType) {
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  return tensorType ? tensorType.getElementType() : containerType;
+}
+
+// Describe the shape of 'container' as mixed static/dynamic sizes: attributes
+// for static dimensions and values for dynamic ones. A scalar has no
+// dimensions, so an empty list is returned for it.
+SmallVector<OpFoldResult> getContainerShape(OpBuilder &builder, Location loc,
+                                            Value container) {
+  if (!isa<TensorType>(container.getType()))
+    return {};
+  return tensor::getMixedSizes(builder, loc, container);
+}
+
+// Build the scalar-or-tensor counterpart of 'containerType' whose element type
+// is 'elementType'. A tensor keeps its shape and takes the new element type; a
+// scalar container has nothing to clone, so 'elementType' itself is returned.
+Type cloneContainerType(Type containerType, Type elementType) {
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return elementType;
+  return tensorType.clone(elementType);
+}
+
+// Get a scalar or tensor constant containing the value given in 'attr'.
+// If 'containerType' is a scalar, a scalar constant is returned. If
+// 'containerType' is a tensor, a tensor splat of shape 'containerShape' is
+// returned.
+//
+// NOTE(review): on the statically shaped tensor path the splat attribute is
+// built with 'containerType' itself, so 'attr' must already have that tensor's
+// element type. The scalar and dynamic paths instead materialize a constant of
+// 'attr.getType()', ignoring the element type of 'containerType'.
+Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
+                           Type containerType,
+                           ArrayRef<OpFoldResult> containerShape) {
+  // A statically shaped tensor can be created with 'arith.constant'
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (tensorType && tensorType.hasStaticShape()) {
+    auto denseElementsAttr = DenseElementsAttr::get(tensorType, attr);
+    return builder.create<arith::ConstantOp>(loc, tensorType, denseElementsAttr);
+  }
+
+  // Scalar and dynamically shaped tensor containers need the scalar constant
+  // to be first materialized.
+  Value containerConstant =
+      builder.create<arith::ConstantOp>(loc, attr.getType(), attr);
+
+  // Create tensor splat if necessary
+  if (tensorType) {
+    containerConstant =
+        builder.create<tensor::SplatOp>(loc, containerConstant, containerShape);
+  }
+  return containerConstant;
+}
+
+// Calculate the size of an unranked tensor starting at dimension 'fromDim' up
+// to, but not including, dimension 'toDim'.
+//
+// The size is the product of the extents of dimensions [fromDim, toDim),
+// accumulated by an 'scf.for' loop. 'one' serves both as the loop step and as
+// the initial accumulator, so an empty range yields 1.
+Value getUnrankedTensorSizeRange(OpBuilder &builder, Location loc, Value input,
+                                 Value fromDim, Value toDim, Value one) {
+  auto loop = builder.create<scf::ForOp>(
+      loc,
+      fromDim,  // lowerBound
+      toDim,  // upperBound
+      one,  // step
+      one,  // iterArgs
+      [&](OpBuilder &builder, Location loc, Value index, ValueRange args) {
+        Value size = builder.create<tensor::DimOp>(loc, input, index);
+        Value totalSize = builder.create<arith::MulIOp>(loc, args.front(), size);
+        builder.create<scf::YieldOp>(loc, totalSize);
+      });
+  return loop.getResult(0);
+}
+
+// Obtain the shape of an unranked tensor. This function returns a 1D tensor of
+// size 'rank' and element type 'index'.
+//
+// The shape is built with 'tensor.generate', whose element i is
+// 'tensor.dim(input, i)'; its static type is 'tensor<?xindex>'.
+Value getUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
+                             Value rank) {
+  auto shapeType =
+      RankedTensorType::get({ShapedType::kDynamic}, builder.getIndexType());
+  auto shape = builder.create<tensor::GenerateOp>(
+      loc,
+      shapeType,
+      rank,
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value size = builder.create<tensor::DimOp>(loc, input, args.front());
+        builder.create<tensor::YieldOp>(loc, size);
+      });
+  return shape;
+}
+
+// Lowers 'quant.qcast' into arithmetic on the quantized storage type, followed
+// by a 'quant.scast' from the storage type to the original quantized result
+// type. Only per-layer quantization (UniformQuantizedType) is lowered so far;
+// any other result element type is reported as a match failure.
+class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
+
+  // Quantize a scalar or ranked tensor 'input' using the single scale and zero
+  // point of 'quantizedType': stored = convert(input / scale + zeroPoint),
+  // clamped to the storage bounds when the type declares explicit ones.
+  Value convertPerLayerScalarOrRanked(
+      OpBuilder &builder, Location loc, Value input,
+      UniformQuantizedType quantizedType) const {
+
+    auto inputType = input.getType();
+    auto expressedType = cast<FloatType>(quantizedType.getExpressedType());
+    auto storageType = cast<IntegerType>(quantizedType.getStorageType());
+    auto storageContainerType = cloneContainerType(inputType, storageType);
+
+    auto inputShape = getContainerShape(builder, loc, input);
+
+    // Scale and zero point scalars
+    auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale());
+    auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape);
+    auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint());
+    auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape);
+
+    auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+    auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+
+    // NOTE(review): arith.fptosi/fptoui round toward zero rather than to the
+    // nearest integer; confirm this is the intended quantization rounding.
+    Value storedValue;
+    if (quantizedType.isSigned()) {
+      storedValue = builder.create<arith::FPToSIOp>(
+          loc, storageContainerType, storedValueAsExpressedType);
+    } else {
+      storedValue = builder.create<arith::FPToUIOp>(
+          loc, storageContainerType, storedValueAsExpressedType);
+    }
+
+    // Clamp stored value if needed. The bound constants are materialized with
+    // the integer storage container type (same type as 'storedValue'), not the
+    // float input type: for statically shaped tensors the splat is built with
+    // the container type passed here, and min/max need matching operand types.
+    if (quantizedType.hasStorageTypeBounds()) {
+      auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
+      auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
+      auto storageMin = getContainerConstant(builder, loc, storageMinAttr,
+                                             storageContainerType, inputShape);
+      auto storageMax = getContainerConstant(builder, loc, storageMaxAttr,
+                                             storageContainerType, inputShape);
+      if (quantizedType.isSigned()) {
+        storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
+        storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
+      } else {
+        storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
+        storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
+      }
+    }
+
+    return storedValue;
+  }
+
+  // Quantize an unranked tensor: reshape it into a 1D tensor holding all of
+  // its elements, quantize that, and reshape the result back to the original
+  // (dynamically computed) shape.
+  Value convertPerLayerUnranked(
+      OpBuilder &builder, Location loc, Value input,
+      UniformQuantizedType quantizedType) const {
+    auto rank = builder.create<tensor::RankOp>(loc, input);
+    auto inputShape = getUnrankedTensorShape(builder, loc, input, rank);
+    auto inputType = cast<UnrankedTensorType>(input.getType());
+
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+    auto inputSize = getUnrankedTensorSizeRange(builder, loc, input, zero, rank, one);
+
+    // Compute collapsed input shape as a 1D 1-sized index tensor
+    auto collapsedInputShapeType = RankedTensorType::get({1}, builder.getIndexType());
+    auto collapsedInputShape = builder.create<tensor::FromElementsOp>(
+        loc, collapsedInputShapeType, inputSize);
+
+    // Reshape input tensor into 1D
+    auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic},
+                                                    inputType.getElementType());
+    auto collapsedInput = builder.create<tensor::ReshapeOp>(
+        loc, collapsedInputType, input, collapsedInputShape);
+
+    // Now we know how to convert a ranked tensor
+    auto collapsedStoredValue = convertPerLayerScalarOrRanked(
+        builder, loc, collapsedInput, quantizedType);
+
+    // Expand stored value back to the original shape
+    auto expandedStoredValueType =
+        UnrankedTensorType::get(quantizedType.getStorageType());
+    auto expandedStoredValue = builder.create<tensor::ReshapeOp>(
+        loc, expandedStoredValueType, collapsedStoredValue, inputShape);
+    return expandedStoredValue;
+  }
+
+public:
+  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // Conversion patterns read operands through the adaptor, which holds any
+    // values already remapped by the conversion driver.
+    auto input = adaptor.getInput();
+    auto resultScalarType = getScalarType(op.getResult().getType());
+
+    // Only per-layer quantization is supported so far. Per-channel results
+    // (UniformQuantizedPerAxisType) and any other result element type, which
+    // the op's type constraints allow, are reported as match failures rather
+    // than producing a null stored value or reaching llvm_unreachable.
+    auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType);
+    if (!quantizedType) {
+      if (isa<UniformQuantizedPerAxisType>(resultScalarType))
+        return rewriter.notifyMatchFailure(
+            op, "per-channel quantization is not supported yet");
+      return rewriter.notifyMatchFailure(op, "unsupported quantized type");
+    }
+
+    Value storedValue =
+        isa<UnrankedTensorType>(input.getType())
+            ? convertPerLayerUnranked(rewriter, loc, input, quantizedType)
+            : convertPerLayerScalarOrRanked(rewriter, loc, input, quantizedType);
+
+    // Cast stored value to result quantized value
+    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
+        op, op.getResult().getType(), storedValue);
+    return success();
+  }
+};
+
 struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
   void runOnOperation() override {
-    Operation *parentOp = getOperation();
+    // Collect the qcast/dcast conversion patterns.
+    RewritePatternSet patterns(&getContext());
+    populateLowerQuantOpsPatterns(patterns);
+
+    // Every 'quant' op must be converted except 'quant.scast', which is kept
+    // legal because the lowering emits it between quantized and storage types.
+    // The dialects the generated arithmetic uses are legal.
+    ConversionTarget target(getContext());
+    target.addLegalOp<quant::StorageCastOp>();
+    target.addIllegalDialect<quant::QuantDialect>();
+    target.addLegalDialect<
+      arith::ArithDialect,
+      linalg::LinalgDialect,
+      scf::SCFDialect,
+      tensor::TensorDialect
+    >();
+
+    // Fail the pass if the partial conversion does not succeed.
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
   }
 };
 
 } // namespace
 
+/// Register the 'quant.dcast' and 'quant.qcast' conversion patterns, bound to
+/// the context that owns 'patterns'.
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(context);
+}
+
 } // namespace quant
 } // namespace mlir

>From 8f8f6de514a07452f60442464fdd8a77a166a9f4 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 11 Jul 2024 18:01:31 -0400
Subject: [PATCH 04/18] Custom formats for 'quant' ops and use of 'shape'
 dialect ops for 'lower-quant-ops'

---
 .../include/mlir/Dialect/Quant/IR/QuantOps.td | 13 +----
 .../mlir/Dialect/Quant/Transforms/Passes.td   |  2 +-
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        | 13 -----
 .../Quant/Transforms/LowerQuantOps.cpp        | 57 ++++---------------
 4 files changed, 15 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 48e2496203ff0..7a6d270dbb6e9 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -44,6 +44,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
   }];
   let arguments = (ins quant_RealValueType:$input);
   let results = (outs quant_RealValueType:$result);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
 }
 
 def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
@@ -72,16 +73,7 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
   }];
   let arguments = (ins quant_RealValueType:$input);
   let results = (outs quant_RealValueType:$result);
-
-  let extraClassDeclaration = [{
-
-    /// Return the primitive (scalar or tensor element) type of the float input.
-    FloatType getFloatType();
-
-    /// Return the primitive (scalar or tensor element) type of the quantized
-    /// result.
-    quant::UniformQuantizedType getQuantizedType();
-  }];
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
 }
 
 def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
@@ -107,6 +99,7 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
   }];
   let arguments = (ins quant_RealOrStorageValueType:$input);
   let results = (outs quant_RealOrStorageValueType:$result);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
   let hasFolder = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index 19aa37f653c01..56e10688b0c98 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -26,7 +26,7 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
     "arith::ArithDialect",
     "linalg::LinalgDialect",
     "quant::QuantDialect",
-    "scf::SCFDialect",
+    "shape::ShapeDialect",
     "tensor::TensorDialect"
   ];
 }
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 2eafa348f6906..e04ca7eb7e715 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -52,19 +52,6 @@ void QuantDialect::initialize() {
 }
 
 
-//===----------------------------------------------------------------------===//
-// QuantizeCastOp
-//===----------------------------------------------------------------------===//
-
-FloatType QuantizeCastOp::getFloatType() {
-  return cast<FloatType>(getPrimitiveType(getInput().getType()));
-}
-
-UniformQuantizedType QuantizeCastOp::getQuantizedType() {
-  return cast<UniformQuantizedType>(getPrimitiveType(getResult().getType()));
-}
-
-
 //===----------------------------------------------------------------------===//
 // StorageCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1067dc4f950d3..1daafdd715155 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -16,7 +16,7 @@
 #include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/Dialect/Quant/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -104,41 +104,6 @@ Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
   return containerConstant;
 }
 
-// Calculate the size of an unranked tensor starting at dimension 'fromDim' up
-// to, but not including, dimension 'toDim'.
-Value getUnrankedTensorSizeRange(OpBuilder &builder, Location loc, Value input,
-                                 Value fromDim, Value toDim, Value one) {
-  auto loop = builder.create<scf::ForOp>(
-      loc,
-      fromDim,  // lowerBound
-      toDim,  // upperBound
-      one,  // step
-      one,  // iterArgs
-      [&](OpBuilder &builder, Location loc, Value index, ValueRange args) {
-        Value size = builder.create<tensor::DimOp>(loc, input, index);
-        Value totalSize = builder.create<arith::MulIOp>(loc, args.front(), size);
-        builder.create<scf::YieldOp>(loc, totalSize);
-      });
-  return loop.getResult(0);
-}
-
-// Obtain the shape of an unranked tensor. This function returns a 1D tensor of
-// size 'rank' and element type 'index'.
-Value getUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
-                             Value rank) {
-  auto shapeType =
-      RankedTensorType::get({ShapedType::kDynamic}, builder.getIndexType());
-  auto shape = builder.create<tensor::GenerateOp>(
-      loc,
-      shapeType,
-      rank,
-      [&](OpBuilder &builder, Location loc, ValueRange args) {
-        Value size = builder.create<tensor::DimOp>(loc, input, args.front());
-        builder.create<tensor::YieldOp>(loc, size);
-      });
-  return shape;
-}
-
 class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
 
   Value convertPerLayerScalarOrRanked(
@@ -191,18 +156,18 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
   Value convertPerLayerUnranked(
       OpBuilder &builder, Location loc, Value input,
       UniformQuantizedType quantizedType) const {
-    auto rank = builder.create<tensor::RankOp>(loc, input);
-    auto inputShape = getUnrankedTensorShape(builder, loc, input, rank);
+    auto *context = builder.getContext();
     auto inputType = cast<UnrankedTensorType>(input.getType());
 
-    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
-    auto inputSize = getUnrankedTensorSizeRange(builder, loc, input, zero, rank, one);
+    auto shapeType = shape::getExtentTensorType(context);
+    auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+    Value inputSize = builder.create<shape::NumElementsOp>(
+        loc, builder.getIndexType(), inputShape);
 
-    // Compute collapsed input shape as a 1D 1-sized index tensor
-    auto collapsedInputShapeType = RankedTensorType::get({1}, builder.getIndexType());
+    // Turn input size into 1D tensor
+    auto collapsedShapeType = shape::getExtentTensorType(context, 1);
     auto collapsedInputShape = builder.create<tensor::FromElementsOp>(
-        loc, collapsedInputShapeType, inputSize);
+        loc, collapsedShapeType, inputSize);
 
     // Reshape input tensor into 1D
     auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic},
@@ -210,7 +175,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
     auto collapsedInput = builder.create<tensor::ReshapeOp>(
         loc, collapsedInputType, input, collapsedInputShape);
 
-    // Now we know how to convert a ranked tensor
+    // We now know how to deal with a 1D ranked input
     auto collapsedStoredValue = convertPerLayerScalarOrRanked(
         builder, loc, collapsedInput, quantizedType);
 
@@ -262,7 +227,7 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
     target.addLegalDialect<
       arith::ArithDialect,
       linalg::LinalgDialect,
-      scf::SCFDialect,
+      shape::ShapeDialect,
       tensor::TensorDialect
     >();
 

>From ba88a9c38a212637c3446aca204d7d8f767fad33 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 15 Jul 2024 18:08:50 -0400
Subject: [PATCH 05/18] Progress in per-channel quantization

---
 .../Quant/Transforms/LowerQuantOps.cpp        | 237 ++++++++++++------
 1 file changed, 167 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1daafdd715155..f5f4de807c8f0 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -104,87 +104,186 @@ Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
   return containerConstant;
 }
 
+std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
+                                              Value input) {
+  // Get unranked input shape and total size
+  auto *context = builder.getContext();
+  auto shapeType = shape::getExtentTensorType(context);
+  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+  Value inputSize = builder.create<shape::NumElementsOp>(
+      loc, builder.getIndexType(), inputShape);
+
+  // Turn input size into 1D tensor
+  auto flatShapeType = shape::getExtentTensorType(context, 1);
+  auto flatInputShape = builder.create<tensor::FromElementsOp>(
+      loc, flatShapeType, inputSize);
+
+  // Reshape input tensor into 1D
+  auto inputType = cast<UnrankedTensorType>(input.getType());
+  auto flatInputType =
+      RankedTensorType::get({ShapedType::kDynamic}, inputType.getElementType());
+  auto flatInput = builder.create<tensor::ReshapeOp>(
+      loc, flatInputType, input, flatInputShape);
+  return std::make_pair(flatInput, inputShape);
+}
+
+Value restoreUnrankedTensor(OpBuilder &builder, Location loc, Value input,
+                            Value shape) {
+  auto inputType = cast<TensorType>(input.getType());
+  auto elementType = inputType.getElementType();
+  auto unrankedType = UnrankedTensorType::get(elementType);
+  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, shape);
+}
+
+Value materializeScales(OpBuilder &builder, Location loc,
+                        UniformQuantizedPerAxisType quantizedType) {
+  auto scales = quantizedType.getScales();
+  auto expressedType = quantizedType.getExpressedType();
+  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
+    return builder.getFloatAttr(expressedType, scale);
+  });
+  auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
+  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
+  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
+}
+
+Value materializeZeroPoints(OpBuilder &builder, Location loc,
+                            UniformQuantizedPerAxisType quantizedType) {
+  auto zeroPoints = quantizedType.getZeroPoints();
+  auto expressedType = quantizedType.getExpressedType();
+  auto zeroPointAttrs = llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
+    return builder.getFloatAttr(expressedType, static_cast<double>(zeroPoint));
+  });
+  auto tensorType = RankedTensorType::get({(int64_t) zeroPoints.size()}, expressedType);
+  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
+  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
+}
+
+Value quantizeValue(OpBuilder &builder, Location loc, Value input,
+                    ArrayRef<OpFoldResult> inputShape, Value scale,
+                    Value zeroPoint, QuantizedType quantizedType) {
+  auto inputType = input.getType();
+  auto storageType = cast<IntegerType>(quantizedType.getStorageType());
+  auto storageContainerType = cloneContainerType(inputType, storageType);
+
+  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+  auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+
+  Value storedValue;
+  if (quantizedType.isSigned()) {
+    storedValue = builder.create<arith::FPToSIOp>(
+        loc, storageContainerType, storedValueAsExpressedType);
+  } else {
+    storedValue = builder.create<arith::FPToUIOp>(
+        loc, storageContainerType, storedValueAsExpressedType);
+  }
+
+  // Clamp stored value if needed
+  if (quantizedType.hasStorageTypeBounds()) {
+    auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
+    auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
+    auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape);
+    auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape);
+    if (quantizedType.isSigned()) {
+      storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
+      storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
+    } else {
+      storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
+      storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
+    }
+  }
+
+  return storedValue;
+}
+
 class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
 
-  Value convertPerLayerScalarOrRanked(
-      OpBuilder &builder, Location loc, Value input,
-      UniformQuantizedType quantizedType) const {
+  Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
+                              UniformQuantizedType quantizedType) const {
 
     auto inputType = input.getType();
     auto expressedType = cast<FloatType>(quantizedType.getExpressedType());
-    auto storageType = cast<IntegerType>(quantizedType.getStorageType());
-    auto storageContainerType = cloneContainerType(inputType, storageType);
 
+    // Create scale and zero point constants
     auto inputShape = getContainerShape(builder, loc, input);
-
-    // Scale and zero point scalars
     auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale());
     auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape);
     auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint());
     auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape);
 
-    auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
-    auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+                         quantizedType);
+  }
 
-    Value storedValue;
-    if (quantizedType.isSigned()) {
-      storedValue = builder.create<arith::FPToSIOp>(
-          loc, storageContainerType, storedValueAsExpressedType);
-    } else {
-      storedValue = builder.create<arith::FPToUIOp>(
-          loc, storageContainerType, storedValueAsExpressedType);
-    }
+  Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
+                        UniformQuantizedType quantizedType) const {
+    // Flatten input if unranked
+    bool isUnranked = isa<UnrankedTensorType>(input.getType());
+    Value shape;
+    if (isUnranked)
+      std::tie(input, shape) = flattenUnrankedTensor(builder, loc, input);
 
-    // Clamp stored value if needed
-    if (quantizedType.hasStorageTypeBounds()) {
-      auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
-      auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
-      auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape);
-      auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape);
-      if (quantizedType.isSigned()) {
-        storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
-        storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
-      } else {
-        storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
-        storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
-      }
-    }
+    // Process ranked tensor
+    auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
 
-    return storedValue;
+    // Restore original shape if unranked
+    if (isUnranked)
+      result = restoreUnrankedTensor(builder, loc, result, shape);
+
+    return result;
   }
-  
-  Value convertPerLayerUnranked(
-      OpBuilder &builder, Location loc, Value input,
-      UniformQuantizedType quantizedType) const {
+
+  Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
+                                UniformQuantizedPerAxisType quantizedType,
+                                int32_t channelAxis) const {
     auto *context = builder.getContext();
-    auto inputType = cast<UnrankedTensorType>(input.getType());
-
-    auto shapeType = shape::getExtentTensorType(context);
-    auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
-    Value inputSize = builder.create<shape::NumElementsOp>(
-        loc, builder.getIndexType(), inputShape);
-
-    // Turn input size into 1D tensor
-    auto collapsedShapeType = shape::getExtentTensorType(context, 1);
-    auto collapsedInputShape = builder.create<tensor::FromElementsOp>(
-        loc, collapsedShapeType, inputSize);
-
-    // Reshape input tensor into 1D
-    auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic},
-                                                    inputType.getElementType());
-    auto collapsedInput = builder.create<tensor::ReshapeOp>(
-        loc, collapsedInputType, input, collapsedInputShape);
-
-    // We now know how to deal with a 1D ranked input
-    auto collapsedStoredValue = convertPerLayerScalarOrRanked(
-        builder, loc, collapsedInput, quantizedType);
-
-    // Expand stored value back to the original shape
-    auto expandedStoredValueType =
-        UnrankedTensorType::get(quantizedType.getStorageType());
-    auto expandedStoredValue = builder.create<tensor::ReshapeOp>(
-        loc, expandedStoredValueType, collapsedStoredValue, inputShape);
-    return expandedStoredValue;
+
+    auto inputType = cast<RankedTensorType>(input.getType());
+    auto inputRank = inputType.getRank();
+
+    auto scales = materializeScales(builder, loc, quantizedType);
+    auto zeroPoints = materializeZeroPoints(builder, loc, quantizedType);
+
+    auto storageType = quantizedType.getStorageType();
+    auto initShape = tensor::getMixedSizes(builder, loc, input);
+    Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        inputRank, utils::IteratorType::parallel);
+    auto channelAxisAffineMap = AffineMap::get(
+        inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+    SmallVector<AffineMap> indexingMaps{
+      builder.getMultiDimIdentityMap(inputRank),
+      channelAxisAffineMap,
+      channelAxisAffineMap,
+      builder.getMultiDimIdentityMap(inputRank)
+    };
+    auto storedValue = builder.create<linalg::GenericOp>(
+        loc,
+        init.getType(),  // resultType
+        ValueRange{input, scales, zeroPoints},  // inputs
+        ValueRange{init},  // outputs
+        indexingMaps,
+        iteratorTypes,
+        [&](OpBuilder& builder, Location loc, ValueRange args) {
+          assert(args.size() == 4);
+          auto expressedValue = args[0];
+          auto scale = args[1];
+          auto zeroPoint = args[2];
+
+          auto storedValue = quantizeValue(builder, loc, expressedValue, {},
+                                           scale, zeroPoint, quantizedType);
+
+          builder.create<linalg::YieldOp>(loc, storedValue);
+        })
+        .getResult(0);
+
+    return storedValue;
+  }
+
+  Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
+                          UniformQuantizedPerAxisType quantizedType) const {
+    return convertPerChannelRanked(builder, loc, input, quantizedType, quantizedType.getQuantizedDimension());
   }
 
 public:
@@ -197,18 +296,16 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
     auto input = op.getInput();
     auto resultScalarType = getScalarType(op.getResult().getType());
 
-    // Per-layer vs per-channel quantization
+    // Flatten unranked tensor input
     Value storedValue;
     if (auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType)) {
-      storedValue = isa<UnrankedTensorType>(input.getType()) ?
-          convertPerLayerUnranked(rewriter, loc, input, quantizedType) :
-          convertPerLayerScalarOrRanked(rewriter, loc, input, quantizedType);
+      storedValue = convertPerLayer(rewriter, loc, input, quantizedType);
     } else if (auto quantizedType = dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
-      // FIXM
+      storedValue = convertPerChannel(rewriter, loc, input, quantizedType);
     } else {
       llvm_unreachable("unexpected quantized type");
     }
-    
+
     // Cast stored value to result quantized value
     rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
         op, op.getResult().getType(), storedValue);

>From 6b9caccad131a0d2e6cf22419b0486cf01f1b28b Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 16 Jul 2024 10:15:24 -0400
Subject: [PATCH 06/18] Support for unranked tensor in per-channel quantization

---
 .../Quant/Transforms/LowerQuantOps.cpp        | 72 +++++++++++++++++--
 1 file changed, 66 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index f5f4de807c8f0..49ef8ea354070 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -120,10 +120,52 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
 
   // Reshape input tensor into 1D
   auto inputType = cast<UnrankedTensorType>(input.getType());
+  auto elementType = inputType.getElementType();
   auto flatInputType =
-      RankedTensorType::get({ShapedType::kDynamic}, inputType.getElementType());
+      RankedTensorType::get({ShapedType::kDynamic}, elementType);
+  auto flatInput = builder.create<tensor::ReshapeOp>(
+      loc, flatInputType, input, flatInputShape);
+  return std::make_pair(flatInput, inputShape);
+}
+
+std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
+                                                        Location loc,
+                                                        Value input,
+                                                        int64_t axis) {
+  // Get full tensor shape
+  auto *context = builder.getContext();
+  auto indexType = builder.getIndexType();
+  auto shapeType = shape::getExtentTensorType(context);
+  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+
+  // Get shape and sizes on left and right of axis
+  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
+  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
+  auto shapeLeft = builder.create<shape::SplitAtOp>(
+      loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
+      .getResult(0);
+  auto sizeLeft = builder.create<shape::NumElementsOp>(
+      loc, indexType, shapeLeft);
+  auto shapeRight = builder.create<shape::SplitAtOp>(
+      loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
+      .getResult(1);
+  auto sizeRight = builder.create<shape::NumElementsOp>(
+      loc, indexType, shapeRight);
+  Value axisSize = builder.create<tensor::DimOp>(loc, input, axisValue);
+
+  // Compute flat input shape as a 3-element 1D tensor
+  auto flatShapeType = shape::getExtentTensorType(context, 3);
+  auto flatInputShape = builder.create<tensor::FromElementsOp>(
+      loc, flatShapeType, ValueRange{sizeLeft, axisSize, sizeRight});
+
+  // Reshape input to 3D tensor
+  auto inputType = cast<UnrankedTensorType>(input.getType());
+  auto elementType = inputType.getElementType();
+  SmallVector<int64_t> flatInputDims(3, ShapedType::kDynamic);
+  auto flatInputType = RankedTensorType::get(flatInputDims, elementType);
   auto flatInput = builder.create<tensor::ReshapeOp>(
       loc, flatInputType, input, flatInputShape);
+
   return std::make_pair(flatInput, inputShape);
 }
 
@@ -219,23 +261,23 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
                         UniformQuantizedType quantizedType) const {
     // Flatten input if unranked
     bool isUnranked = isa<UnrankedTensorType>(input.getType());
-    Value shape;
+    Value inputShape;
     if (isUnranked)
-      std::tie(input, shape) = flattenUnrankedTensor(builder, loc, input);
+      std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
 
     // Process ranked tensor
     auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
 
     // Restore original shape if unranked
     if (isUnranked)
-      result = restoreUnrankedTensor(builder, loc, result, shape);
+      result = restoreUnrankedTensor(builder, loc, result, inputShape);
 
     return result;
   }
 
   Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
                                 UniformQuantizedPerAxisType quantizedType,
-                                int32_t channelAxis) const {
+                                int64_t channelAxis) const {
     auto *context = builder.getContext();
 
     auto inputType = cast<RankedTensorType>(input.getType());
@@ -283,7 +325,25 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
 
   Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
                           UniformQuantizedPerAxisType quantizedType) const {
-    return convertPerChannelRanked(builder, loc, input, quantizedType, quantizedType.getQuantizedDimension());
+    // Flatten unranked tensor if necessary
+    bool isUnranked = isa<UnrankedTensorType>(input.getType());
+    int64_t channelAxis = quantizedType.getQuantizedDimension();
+    Value inputShape;
+    if (isUnranked) {
+      std::tie(input, inputShape) =
+          flattenUnrankedTensorAroundAxis(builder, loc, input, channelAxis);
+      channelAxis = 1;
+    }
+
+    // Work on a ranked tensor
+    auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
+                                          channelAxis);
+
+    // Restore original tensor shape if unranked
+    if (isUnranked)
+      result = restoreUnrankedTensor(builder, loc, result, inputShape);
+
+    return result;
   }
 
 public:

>From b2fee689c74012651662d4edc8c00988aea5344b Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 16 Jul 2024 15:24:15 -0400
Subject: [PATCH 07/18] Refactored 'quant.qcast' lowering. Ready to begin
 'quant.dcast'

---
 .../Quant/Transforms/LowerQuantOps.cpp        | 270 ++++++++++--------
 1 file changed, 158 insertions(+), 112 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 49ef8ea354070..68a4328128292 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -49,59 +49,50 @@ class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeC
 // QuantizeCastOp
 //===----------------------------------------------------------------------===//
 
-// If 'containerType' is a tensor, return its element type. If it is a scalar,
+// If 'inputType' is a tensor, return its element type. If it is a scalar,
 // return it as is.
-Type getScalarType(Type containerType) {
-  if (auto tensorType = dyn_cast<TensorType>(containerType))
+Type getScalarType(Type inputType) {
+  if (auto tensorType = dyn_cast<TensorType>(inputType))
     return tensorType.getElementType();
-  return containerType;
+  return inputType;
 }
 
-// Return the shape of a container as a combination of attributes (static
-// dimensions) and values (dynamic dimensions). If 'container' is a scalar,
-// an empty list is returned. If 'container' is a tensor, its shape is returned.
-SmallVector<OpFoldResult> getContainerShape(OpBuilder &builder, Location loc,
-                                            Value container) {
-  if (isa<TensorType>(container.getType()))
-    return tensor::getMixedSizes(builder, loc, container);
+// Return the shape of an input value as a list of attributes (static dimensions)
+// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
+// returned. If 'input' is a tensor, its shape is returned.
+SmallVector<OpFoldResult>
+getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+  if (isa<TensorType>(input.getType()))
+    return tensor::getMixedSizes(builder, loc, input);
   return {};
 }
 
-// Clone the given 'containerType' with the new given 'elementType'. If
-// 'containerType' is a scalar type, there is nothing to clone, and
-// 'elementType' itself is returned. If 'constainerType' is a tensor, its
-// shape is cloned but the new element type is used.
-Type cloneContainerType(Type containerType, Type elementType) {
-  if (auto tensorType = dyn_cast<TensorType>(containerType))
+// If 'referenceType' is a scalar, return 'elementType' as is. If
+// 'referenceType' is a tensor, return another tensor with the same shape and
+// elements of type 'elementType'.
+Type getScalarOrTensorType(Type elementType, Type referenceType) {
+  if (auto tensorType = dyn_cast<TensorType>(referenceType))
     return tensorType.clone(elementType);
   return elementType;
 }
 
-// Get a scalar or tensor constant containing the value given in 'attr'.
-// If 'containerType' is a scalar, a scalar constant is returned. If
-// 'containerType' is a tensor, a tensor splat of shape 'containerShape' is
-// returned.
-Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
-                           Type containerType,
-                           ArrayRef<OpFoldResult> containerShape) {
-  // A statically shaped tensor can be created with 'arith.constant'
-  auto tensorType = dyn_cast<TensorType>(containerType);
-  if (tensorType && tensorType.hasStaticShape()) {
-    auto denseElementsAttr = DenseElementsAttr::get(tensorType, attr);
-    return builder.create<arith::ConstantOp>(loc, tensorType, denseElementsAttr);
+// Return a constant with the given value. If 'referenceType' is a tensor, a
+// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
+// scalar, 'referenceShape' is ignored and a scalar constant is returned.
+Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
+                                Type referenceType,
+                                ArrayRef<OpFoldResult> referenceShape) {
+  // If the result type is a scalar, return the unmodified scalar constant.
+  auto tensorType = dyn_cast<TensorType>(referenceType);
+  if (!tensorType) {
+    assert(referenceShape.empty());
+    return scalar;
   }
 
-  // Scalar and dynamically shaped tensor containers need the scalar constant
-  // to be first materialized.
-  Value containerConstant =
-      builder.create<arith::ConstantOp>(loc, attr.getType(), attr);
-
-  // Create tensor splat if necessary
-  if (tensorType) {
-    containerConstant =
-        builder.create<tensor::SplatOp>(loc, containerConstant, containerShape);
-  }
-  return containerConstant;
+  // Create tensor splat
+  auto tensorConstant =
+      builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
+  return tensorConstant;
 }
 
 std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
@@ -131,7 +122,8 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
 std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
                                                         Location loc,
                                                         Value input,
-                                                        int64_t axis) {
+                                                        int64_t axis,
+                                                        int64_t axisSize) {
   // Get full tensor shape
   auto *context = builder.getContext();
   auto indexType = builder.getIndexType();
@@ -151,34 +143,34 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
       .getResult(1);
   auto sizeRight = builder.create<shape::NumElementsOp>(
       loc, indexType, shapeRight);
-  Value axisSize = builder.create<tensor::DimOp>(loc, input, axisValue);
 
   // Compute flat input shape as a 3-element 1D tensor
+  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
   auto flatShapeType = shape::getExtentTensorType(context, 3);
   auto flatInputShape = builder.create<tensor::FromElementsOp>(
-      loc, flatShapeType, ValueRange{sizeLeft, axisSize, sizeRight});
+      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
 
   // Reshape input to 3D tensor
   auto inputType = cast<UnrankedTensorType>(input.getType());
   auto elementType = inputType.getElementType();
-  SmallVector<int64_t> flatInputDims(3, ShapedType::kDynamic);
-  auto flatInputType = RankedTensorType::get(flatInputDims, elementType);
+  auto flatInputType = RankedTensorType::get(
+      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
   auto flatInput = builder.create<tensor::ReshapeOp>(
       loc, flatInputType, input, flatInputShape);
 
   return std::make_pair(flatInput, inputShape);
 }
 
-Value restoreUnrankedTensor(OpBuilder &builder, Location loc, Value input,
-                            Value shape) {
-  auto inputType = cast<TensorType>(input.getType());
+Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
+                                 Value inputShape) {
+  auto inputType = cast<RankedTensorType>(input.getType());
   auto elementType = inputType.getElementType();
   auto unrankedType = UnrankedTensorType::get(elementType);
-  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, shape);
+  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
 }
 
-Value materializeScales(OpBuilder &builder, Location loc,
-                        UniformQuantizedPerAxisType quantizedType) {
+Value materializePerChannelScales(OpBuilder &builder, Location loc,
+                                  UniformQuantizedPerAxisType quantizedType) {
   auto scales = quantizedType.getScales();
   auto expressedType = quantizedType.getExpressedType();
   auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
@@ -189,52 +181,100 @@ Value materializeScales(OpBuilder &builder, Location loc,
   return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
 }
 
-Value materializeZeroPoints(OpBuilder &builder, Location loc,
+Value materializePerChannelZeroPoints(OpBuilder &builder, Location loc,
                             UniformQuantizedPerAxisType quantizedType) {
   auto zeroPoints = quantizedType.getZeroPoints();
-  auto expressedType = quantizedType.getExpressedType();
-  auto zeroPointAttrs = llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
-    return builder.getFloatAttr(expressedType, static_cast<double>(zeroPoint));
-  });
-  auto tensorType = RankedTensorType::get({(int64_t) zeroPoints.size()}, expressedType);
+  auto storageType = quantizedType.getStorageType();
+  auto zeroPointAttrs = llvm::map_to_vector(
+      zeroPoints,
+      [&](int64_t zeroPoint) -> Attribute {
+        return builder.getIntegerAttr(storageType, zeroPoint);
+      });
+  auto tensorType =
+      RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
   auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
   return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
 }
 
-Value quantizeValue(OpBuilder &builder, Location loc, Value input,
-                    ArrayRef<OpFoldResult> inputShape, Value scale,
-                    Value zeroPoint, QuantizedType quantizedType) {
-  auto inputType = input.getType();
-  auto storageType = cast<IntegerType>(quantizedType.getStorageType());
-  auto storageContainerType = cloneContainerType(inputType, storageType);
-
-  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
-  auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+                          ArrayRef<OpFoldResult> inputShape,
+                          QuantizedType quantizedType) {
+  // If quantized type does not narrow down the storage type range, there is
+  // nothing to do.
+  if (!quantizedType.hasStorageTypeBounds())
+    return input;
 
-  Value storedValue;
+  // Materialize bounds
+  auto inputType = input.getType();
+  auto storageType = quantizedType.getStorageType();
+  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
+      loc, quantizedType.getStorageTypeMin(), storageType);
+  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
+      loc, quantizedType.getStorageTypeMax(), storageType);
+  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
+                                              inputType, inputShape);
+  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
+                                              inputType, inputShape);
+
+  // Clamp
   if (quantizedType.isSigned()) {
-    storedValue = builder.create<arith::FPToSIOp>(
-        loc, storageContainerType, storedValueAsExpressedType);
+    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
+    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
   } else {
-    storedValue = builder.create<arith::FPToUIOp>(
-        loc, storageContainerType, storedValueAsExpressedType);
+    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
+    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
   }
+  return input;
+}
 
-  // Clamp stored value if needed
-  if (quantizedType.hasStorageTypeBounds()) {
-    auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
-    auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
-    auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape);
-    auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape);
-    if (quantizedType.isSigned()) {
-      storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
-      storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
-    } else {
-      storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
-      storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
-    }
-  }
+Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
+                            Type resultType, bool isSigned) {
+  if (isSigned)
+    return builder.create<arith::FPToSIOp>(loc, resultType, input);
+  return builder.create<arith::FPToUIOp>(loc, resultType, input);
+}
+
+Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
+                            Type resultType, bool isSigned) {
+  if (isSigned)
+    return builder.create<arith::SIToFPOp>(loc, resultType, input);
+  return builder.create<arith::UIToFPOp>(loc, resultType, input);
+}
 
+// Quantize a floating-point input using the given scale, input shape, and
+// storage type bounds in the given quantized type.
+Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+                             ArrayRef<OpFoldResult> inputShape, Value scale,
+                             Value zeroPoint, QuantizedType quantizedType) {
+  // Convert scale and zero point to tensors if necessary
+  auto inputType = input.getType();
+  scale = getScalarOrTensorConstant(
+      builder, loc, scale, inputType, inputShape);
+  zeroPoint = getScalarOrTensorConstant(
+      builder, loc, zeroPoint, inputType, inputShape);
+
+  // Convert zero point from storage to expressed type
+  auto expressedScalarOrTensorType =
+      getScalarOrTensorType(quantizedType.getExpressedType(), inputType);
+  zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+                                    expressedScalarOrTensorType,
+                                    quantizedType.isSigned());
+
+  // Scale input and add zero point
+  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+  auto storedValueAsExpressedType =
+      builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+
+  // Convert to storage type
+  auto storageScalarOrTensorType =
+      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
+  auto storedValue = convertFloatToInteger(
+      builder, loc, storedValueAsExpressedType, storageScalarOrTensorType,
+      quantizedType.isSigned());
+
+  // Clamp stored value it if the storage type is bound
+  storedValue =
+      clampScalarOrTensor(builder, loc, storedValue, inputShape, quantizedType);
   return storedValue;
 }
 
@@ -243,18 +283,21 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
   Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
                               UniformQuantizedType quantizedType) const {
 
-    auto inputType = input.getType();
-    auto expressedType = cast<FloatType>(quantizedType.getExpressedType());
-
     // Create scale and zero point constants
-    auto inputShape = getContainerShape(builder, loc, input);
-    auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale());
-    auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape);
-    auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint());
-    auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape);
-
-    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
-                         quantizedType);
+    auto expressedType = quantizedType.getExpressedType();
+    auto storageType = quantizedType.getStorageType();
+    auto scaleAttr =
+        builder.getFloatAttr(expressedType, quantizedType.getScale());
+    auto scale =
+        builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
+    auto zeroPointAttr =
+        builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
+    auto zeroPoint =
+        builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
+
+    auto inputShape = getScalarOrTensorShape(builder, loc, input);
+    return quantizeScalarOrTensor(builder, loc, input, inputShape, scale,
+                                  zeroPoint, quantizedType);
   }
 
   Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
@@ -270,7 +313,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
 
     // Restore original shape if unranked
     if (isUnranked)
-      result = restoreUnrankedTensor(builder, loc, result, inputShape);
+      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
 
     return result;
   }
@@ -283,8 +326,9 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
     auto inputType = cast<RankedTensorType>(input.getType());
     auto inputRank = inputType.getRank();
 
-    auto scales = materializeScales(builder, loc, quantizedType);
-    auto zeroPoints = materializeZeroPoints(builder, loc, quantizedType);
+    auto scales = materializePerChannelScales(builder, loc, quantizedType);
+    auto zeroPoints =
+        materializePerChannelZeroPoints(builder, loc, quantizedType);
 
     auto storageType = quantizedType.getStorageType();
     auto initShape = tensor::getMixedSizes(builder, loc, input);
@@ -313,10 +357,10 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
           auto scale = args[1];
           auto zeroPoint = args[2];
 
-          auto storedValue = quantizeValue(builder, loc, expressedValue, {},
-                                           scale, zeroPoint, quantizedType);
+          auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
+                                               scale, zeroPoint, quantizedType);
 
-          builder.create<linalg::YieldOp>(loc, storedValue);
+          builder.create<linalg::YieldOp>(loc, result);
         })
         .getResult(0);
 
@@ -325,13 +369,14 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
 
   Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
                           UniformQuantizedPerAxisType quantizedType) const {
-    // Flatten unranked tensor if necessary
+    // Flatten unranked tensor into a 3D ranked tensor if necessary
     bool isUnranked = isa<UnrankedTensorType>(input.getType());
     int64_t channelAxis = quantizedType.getQuantizedDimension();
+    int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
     Value inputShape;
     if (isUnranked) {
-      std::tie(input, inputShape) =
-          flattenUnrankedTensorAroundAxis(builder, loc, input, channelAxis);
+      std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+          builder, loc, input, channelAxis, channelAxisSize);
       channelAxis = 1;
     }
 
@@ -341,7 +386,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
 
     // Restore original tensor shape if unranked
     if (isUnranked)
-      result = restoreUnrankedTensor(builder, loc, result, inputShape);
+      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
 
     return result;
   }
@@ -357,18 +402,19 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
     auto resultScalarType = getScalarType(op.getResult().getType());
 
     // Flatten unranked tensor input
-    Value storedValue;
+    Value result;
     if (auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType)) {
-      storedValue = convertPerLayer(rewriter, loc, input, quantizedType);
-    } else if (auto quantizedType = dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
-      storedValue = convertPerChannel(rewriter, loc, input, quantizedType);
+      result = convertPerLayer(rewriter, loc, input, quantizedType);
+    } else if (auto quantizedType =
+                   dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
+      result = convertPerChannel(rewriter, loc, input, quantizedType);
     } else {
-      llvm_unreachable("unexpected quantized type");
+      llvm_unreachable("unexpected uniform quantized type");
     }
 
     // Cast stored value to result quantized value
     rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
-        op, op.getResult().getType(), storedValue);
+        op, op.getResult().getType(), result);
     return success();
   }
 };

>From f032c472134642dc44ab6134d2e2646ca13ca341 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 17 Jul 2024 11:31:20 -0400
Subject: [PATCH 08/18] Per-layer quantization unit tests

---
 .../Quant/Transforms/LowerQuantOps.cpp        |  50 +++---
 mlir/test/Dialect/Quant/lower-quant-ops.mlir  | 165 ++++++++++++++++++
 2 files changed, 194 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Dialect/Quant/lower-quant-ops.mlir

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 68a4328128292..e6942899cd638 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Quant/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -181,8 +182,9 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
   return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
 }
 
-Value materializePerChannelZeroPoints(OpBuilder &builder, Location loc,
-                            UniformQuantizedPerAxisType quantizedType) {
+Value materializePerChannelZeroPoints(
+    OpBuilder &builder, Location loc,
+    UniformQuantizedPerAxisType quantizedType) {
   auto zeroPoints = quantizedType.getZeroPoints();
   auto storageType = quantizedType.getStorageType();
   auto zeroPointAttrs = llvm::map_to_vector(
@@ -246,36 +248,42 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
 Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                              ArrayRef<OpFoldResult> inputShape, Value scale,
                              Value zeroPoint, QuantizedType quantizedType) {
-  // Convert scale and zero point to tensors if necessary
+  // Convert scale to tensor if necessary
   auto inputType = input.getType();
   scale = getScalarOrTensorConstant(
       builder, loc, scale, inputType, inputShape);
-  zeroPoint = getScalarOrTensorConstant(
-      builder, loc, zeroPoint, inputType, inputShape);
 
-  // Convert zero point from storage to expressed type
-  auto expressedScalarOrTensorType =
-      getScalarOrTensorType(quantizedType.getExpressedType(), inputType);
-  zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
-                                    expressedScalarOrTensorType,
-                                    quantizedType.isSigned());
-
-  // Scale input and add zero point
+  // Scale input
   auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
-  auto storedValueAsExpressedType =
-      builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
 
-  // Convert to storage type
+  // Skip unnecessary computations if no zero point is given
+  Value storedValueFloat = scaledValue;
+  if (matchPattern(zeroPoint, m_NonZero())) {
+    // Convert zero point to tensor if necessary
+    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+                                          inputShape);
+
+    // Convert zero point from storage to expressed type
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+                                      scale.getType(),
+                                      quantizedType.isSigned());
+
+    // Add zero point to stored value
+    storedValueFloat =
+        builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+  }
+
+  // Convert stored value to storage type
   auto storageScalarOrTensorType =
       getScalarOrTensorType(quantizedType.getStorageType(), inputType);
-  auto storedValue = convertFloatToInteger(
-      builder, loc, storedValueAsExpressedType, storageScalarOrTensorType,
+  auto storedValueInt = convertFloatToInteger(
+      builder, loc, storedValueFloat, storageScalarOrTensorType,
       quantizedType.isSigned());
 
   // Clamp stored value it if the storage type is bound
-  storedValue =
-      clampScalarOrTensor(builder, loc, storedValue, inputShape, quantizedType);
-  return storedValue;
+  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
+                                                inputShape, quantizedType);
+  return storedValueClamped;
 }
 
 class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
new file mode 100644
index 0000000000000..0e151c514eebc
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @qcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : f32 to i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : i8 to !quant.uniform<i8:f32, 2.000000e+00:10>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8:f32, 2.000000e+00:10>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_scalar(%arg0: f32) -> !qalias {
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_scalar_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_NEG_5:.*]] = arith.constant -5 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_5]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+
+!qalias = !quant.uniform<i8<-5:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_bounds(%arg0: f32) -> !qalias {
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_scalar_unsigned_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptoui %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxui %[[STORED_INT]], %[[C_2]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minui %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+
+!qalias = !quant.uniform<u8<2:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_FLOAT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<3x?x5xf32> to tensor<3x?x5x!qalias>
+  return %0 : tensor<3x?x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]] : tensor<3x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_SPLAT]] : tensor<3x5xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<3x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<3x5xi8> to tensor<3x5xf32>
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<3x5xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<3x5xf32> to tensor<3x5xi8>
+
+// CHECK-DAG: %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK-DAG: %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK-DAG: %[[SPLAT_NEG_8:.*]] = tensor.splat %[[C_NEG_8]] : tensor<3x5xi8>
+// CHECK-DAG: %[[SPLAT_7:.*]] = tensor.splat %[[C_7]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[SPLAT_NEG_8]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[SPLAT_7]] : tensor<3x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : tensor<3x5xi8> to tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8<-8:7>:f32, 2.0:10>
+func.func @qcast_per_layer_ranked_bounds(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias>
+  return %0 : tensor<3x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[SIZE:.*]] = shape.num_elements %[[SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
+// CHECK: %[[RANKED_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[RANKED_INPUT]], %[[C_0]] : tensor<?xf32>
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[RANKED_INPUT]], %[[SCALE_SPLAT]] : tensor<?xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<?xf32> to tensor<?xi8>
+
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[STORED_INT]](%[[SHAPE]]) : (tensor<?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+  return %0 : tensor<*x!qalias>
+}
+

>From 3a913979546c071eaae5e681d4313e5abf18e523 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 17 Jul 2024 12:45:58 -0400
Subject: [PATCH 09/18] Completed unit tests for 'quant.qcast'

---
 .../Quant/Transforms/LowerQuantOps.cpp        |   2 +-
 mlir/test/Dialect/Quant/lower-quant-ops.mlir  | 117 ++++++++++++++++++
 2 files changed, 118 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index e6942899cd638..1ffbd7032ae58 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -258,7 +258,7 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
 
   // Skip unnecessary computations if no zero point is given
   Value storedValueFloat = scaledValue;
-  if (matchPattern(zeroPoint, m_NonZero())) {
+  if (!matchPattern(zeroPoint, m_Zero())) {
     // Convert zero point to tensor if necessary
     zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                           inputShape);
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 0e151c514eebc..1aa14c615c6e5 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -163,3 +163,120 @@ func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
   return %0 : tensor<*x!qalias>
 }
 
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x?x?x5xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<4x?x?x5xf32>
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<4x?x?x5xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<4x?x?x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x?x?x5xi8> to tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> {
+  %0 = "quant.qcast"(%arg0) : (tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias>
+  return %0 : tensor<4x?x?x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x2x5xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<0> : tensor<2xi8>
+
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x5xi8>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x2x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK:   %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK:   %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_8]] : i8
+// CHECK:   %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_7]] : i8
+// CHECK:   linalg.yield %[[STORED_CLAMPED]] : i8
+// CHECK: } -> tensor<4x2x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x2x5xi8> to tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+
+!qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
+func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
+  %0 = "quant.qcast"(%arg0) : (tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias>
+  return %0 : tensor<4x2x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+
+// CHECK: %[[CHANNEL_AXIS_SIZE:.*]] = arith.constant 3 : index
+// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[CHANNEL_AXIS_SIZE]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[FLAT_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_0]] : tensor<?x3x?xf32>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_2]] : tensor<?x3x?xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[FLAT_INPUT]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xf32>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<?x3x?xi8>
+
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+  %0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias>
+  return %0 : tensor<*x!qalias>
+}
+

>From b50b2e12ebc9d981ac0d39b3c69d5c0f5f3c0afe Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 09:34:57 -0400
Subject: [PATCH 10/18] Unit test for 0D tensor

---
 mlir/test/Dialect/Quant/lower-quant-ops.mlir | 29 ++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 1aa14c615c6e5..ad3be870916dc 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -86,9 +86,9 @@ func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
 // CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
 // CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
 // CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32>
-// CHECK: %[[STORED_FLOAT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
 
-// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_FLOAT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
 // CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
 
 !qalias = !quant.uniform<i8:f32, 2.0:10>
@@ -99,6 +99,31 @@ func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qal
 
 // -----
 
+// CHECK-LABEL: @qcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
+  %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
+  return %0 : tensor<!qalias>
+}
+
+// -----
+
 // CHECK-LABEL: @qcast_per_layer_ranked_bounds
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
 

>From 9fc4be12215ad3bb9940203bda6d0114b05fa9c9 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 12:36:41 -0400
Subject: [PATCH 11/18] Support for 'quant.dcast' and unit tests

---
 .../Quant/Transforms/LowerQuantOps.cpp        | 217 ++++++++++++++++--
 mlir/test/Dialect/Quant/lower-quant-ops.mlir  |  95 ++++++--
 2 files changed, 264 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1ffbd7032ae58..1b2d36dc672cf 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -30,26 +30,6 @@ namespace quant {
 
 namespace {
 
-//===----------------------------------------------------------------------===//
-// DequantizeCastOp
-//===----------------------------------------------------------------------===//
-
-class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
-public:
-  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    return success();
-  }
-};
-
-
-//===----------------------------------------------------------------------===//
-// QuantizeCastOp
-//===----------------------------------------------------------------------===//
-
 // If 'inputType' is a tensor, return its element type. If it is a scalar,
 // return it as is.
 Type getScalarType(Type inputType) {
@@ -243,8 +223,9 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
   return builder.create<arith::UIToFPOp>(loc, resultType, input);
 }
 
-// Quantize a floating-point input using the given scale, input shape, and
-// storage type bounds in the given quantized type.
+// Quantize a floating-point input using the given input shape, scale, and
+// zero point. The stored value is clamped using the storage bounds encoded in
+// the given quantized type.
 Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                              ArrayRef<OpFoldResult> inputShape, Value scale,
                              Value zeroPoint, QuantizedType quantizedType) {
@@ -286,6 +267,196 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
   return storedValueClamped;
 }
 
+Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+                               ArrayRef<OpFoldResult> inputShape, Value scale,
+                               Value zeroPoint, QuantizedType quantizedType) {
+  // Convert scale to tensor if necessary
+  auto inputType = input.getType();
+  scale = getScalarOrTensorConstant(
+      builder, loc, scale, inputType, inputShape);
+
+  // Convert stored value to the expressed (floating-point) type
+  auto result = convertIntegerToFloat(
+      builder, loc, input, scale.getType(), quantizedType.isSigned());
+
+  // Skip the zero point adjustment if the zero point is a constant zero
+  if (!matchPattern(zeroPoint, m_Zero())) {
+    // Convert zero point to tensor if necessary
+    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+                                          inputShape);
+
+    // Convert zero point from storage to expressed type
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+                                      scale.getType(),
+                                      quantizedType.isSigned());
+
+    // Subtract zero point from stored value
+    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
+  }
+
+  // Multiply by scale to obtain the expressed value
+  result = builder.create<arith::MulFOp>(loc, result, scale);
+  return result;
+}
+
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+  
+  Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
+                              UniformQuantizedType quantizedType) const {
+
+    // Create scale and zero point constants
+    auto expressedType = quantizedType.getExpressedType();
+    auto storageType = quantizedType.getStorageType();
+    auto scaleAttr =
+        builder.getFloatAttr(expressedType, quantizedType.getScale());
+    auto scale =
+        builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
+    auto zeroPointAttr =
+        builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
+    auto zeroPoint =
+        builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
+
+    auto inputShape = getScalarOrTensorShape(builder, loc, input);
+    return dequantizeScalarOrTensor(builder, loc, input, inputShape, scale,
+                                    zeroPoint, quantizedType);
+  }
+
+  Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
+                        UniformQuantizedType quantizedType) const {
+    // Flatten input if unranked
+    bool isUnranked = isa<UnrankedTensorType>(input.getType());
+    Value inputShape;
+    if (isUnranked)
+      std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
+
+    // Process ranked tensor
+    auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
+
+    // Restore original shape if unranked
+    if (isUnranked)
+      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+    return result;
+  }
+
+  Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
+                                UniformQuantizedPerAxisType quantizedType,
+                                int64_t channelAxis) const {
+    auto *context = builder.getContext();
+
+    auto inputType = cast<RankedTensorType>(input.getType());
+    auto inputRank = inputType.getRank();
+
+    auto scales = materializePerChannelScales(builder, loc, quantizedType);
+    auto zeroPoints =
+        materializePerChannelZeroPoints(builder, loc, quantizedType);
+
+    auto storageType = quantizedType.getStorageType();
+    auto initShape = tensor::getMixedSizes(builder, loc, input);
+    Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        inputRank, utils::IteratorType::parallel);
+    auto channelAxisAffineMap = AffineMap::get(
+        inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+    SmallVector<AffineMap> indexingMaps{
+      builder.getMultiDimIdentityMap(inputRank),
+      channelAxisAffineMap,
+      channelAxisAffineMap,
+      builder.getMultiDimIdentityMap(inputRank)
+    };
+    auto storedValue = builder.create<linalg::GenericOp>(
+        loc,
+        init.getType(),  // resultType
+        ValueRange{input, scales, zeroPoints},  // inputs
+        ValueRange{init},  // outputs
+        indexingMaps,
+        iteratorTypes,
+        [&](OpBuilder& builder, Location loc, ValueRange args) {
+          assert(args.size() == 4);
+          auto expressedValue = args[0];
+          auto scale = args[1];
+          auto zeroPoint = args[2];
+
+          auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
+                                               scale, zeroPoint, quantizedType);
+
+          builder.create<linalg::YieldOp>(loc, result);
+        })
+        .getResult(0);
+
+    return storedValue;
+  }
+
+  Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
+                          UniformQuantizedPerAxisType quantizedType) const {
+    // Flatten unranked tensor into a 3D ranked tensor if necessary
+    bool isUnranked = isa<UnrankedTensorType>(input.getType());
+    int64_t channelAxis = quantizedType.getQuantizedDimension();
+    int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+    Value inputShape;
+    if (isUnranked) {
+      std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+          builder, loc, input, channelAxis, channelAxisSize);
+      channelAxis = 1;
+    }
+
+    // Work on a ranked tensor
+    auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
+                                          channelAxis);
+
+    // Restore original tensor shape if unranked
+    if (isUnranked)
+      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+    return result;
+  }
+
+public:
+  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto input = op.getInput();
+    auto quantizedType =
+        cast<QuantizedType>(getScalarType(op.getInput().getType()));
+
+    // Convert quantized input to storage type
+    auto storageScalarOrTensorType =
+        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
+    input = rewriter.create<quant::StorageCastOp>(
+        loc, storageScalarOrTensorType, input);
+
+    // Flatten unranked tensor input
+    Value result;
+    if (auto uniformQuantizedType =
+            dyn_cast<UniformQuantizedType>(quantizedType)) {
+      result = convertPerLayer(rewriter, loc, input, uniformQuantizedType);
+    } else if (auto uniformQuantizedPerAxisType =
+                   dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
+      result =
+          convertPerChannel(rewriter, loc, input, uniformQuantizedPerAxisType);
+    } else {
+      llvm_unreachable("unexpected quantized type");
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
 class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
 
   Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
@@ -417,7 +588,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
                    dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
       result = convertPerChannel(rewriter, loc, input, quantizedType);
     } else {
-      llvm_unreachable("unexpected uniform quantized type");
+      llvm_unreachable("unexpected quantized type");
     }
 
     // Cast stored value to result quantized value
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index ad3be870916dc..1030d5b20620b 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -72,6 +72,31 @@ func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
 
 // -----
 
+// CHECK-LABEL: @qcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
+  %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
+  return %0 : tensor<!qalias>
+}
+
+// -----
+
 // CHECK-LABEL: @qcast_per_layer_ranked
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32>
 
@@ -99,31 +124,6 @@ func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qal
 
 // -----
 
-// CHECK-LABEL: @qcast_per_layer_0d
-// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
-
-// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
-
-// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
-// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
-
-// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
-// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
-// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
-// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
-
-// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
-// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
-
-!qalias = !quant.uniform<i8:f32, 2.0:10>
-func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
-  %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
-  return %0 : tensor<!qalias>
-}
-
-// -----
-
 // CHECK-LABEL: @qcast_per_layer_ranked_bounds
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
 
@@ -305,3 +305,48 @@ func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias>
   return %0 : tensor<*x!qalias>
 }
 
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<u8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return %0 : f32
+}
+

>From 5b40f256c6dd31f511da5ff58608e9c28732a0b6 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 14:33:38 -0400
Subject: [PATCH 12/18] Refactored quant.qcast and quant.dcast common code

---
 .../Quant/Transforms/LowerQuantOps.cpp        | 418 +++++++-----------
 1 file changed, 164 insertions(+), 254 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1b2d36dc672cf..77be3dd11d3ad 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -223,12 +223,11 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
   return builder.create<arith::UIToFPOp>(loc, resultType, input);
 }
 
-// Quantize a floating-point input using the given input shape, scale, and
-// zero point. The stored value is clamped using the storage bounds encoded in
-// the given quantized type.
-Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
-                             ArrayRef<OpFoldResult> inputShape, Value scale,
-                             Value zeroPoint, QuantizedType quantizedType) {
+// Quantize a scalar or ranked tensor value. The stored value is clamped using 
+// the storage bounds encoded in the given quantized type.
+Value quantizeValue(OpBuilder &builder, Location loc, Value input,
+                    ArrayRef<OpFoldResult> inputShape, Value scale,
+                    Value zeroPoint, QuantizedType quantizedType) {
   // Convert scale to tensor if necessary
   auto inputType = input.getType();
   scale = getScalarOrTensorConstant(
@@ -267,9 +266,10 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
   return storedValueClamped;
 }
 
-Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
-                               ArrayRef<OpFoldResult> inputShape, Value scale,
-                               Value zeroPoint, QuantizedType quantizedType) {
+// Dequantize a scalar or ranked tensor value.
+Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
+                      ArrayRef<OpFoldResult> inputShape, Value scale,
+                      Value zeroPoint, QuantizedType quantizedType) {
   // Convert scale to tensor if necessary
   auto inputType = input.getType();
   scale = getScalarOrTensorConstant(
@@ -299,125 +299,168 @@ Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
   return result;
 }
 
+// Convert a scalar or ranked tensor input with the given scale and zero point
+// values. The direction of the conversion is selected by 'op': the input is
+// quantized when 'op' is a 'quant.qcast' and dequantized when 'op' is a
+// 'quant.dcast'.
+//
+// - op
+//   The 'quant.qcast' or 'quant.dcast' op being lowered. Passing any other op
+//   is a programming error.
+//
+// - input
+//   Scalar or ranked tensor value.
+//
+// - inputShape
+//   If 'input' is a tensor, combination of attributes/values representing its
+//   static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - scale
+//   Scale as a scalar value.
+//
+// - zeroPoint
+//   Zero point as a scalar value.
+//
+// - quantizedType
+//   Scalar quantized type of the result ('quant.qcast') or of the input
+//   ('quant.dcast').
+//
+Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
+                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
+                    Value zeroPoint, QuantizedType quantizedType) {
+  if (isa<QuantizeCastOp>(op))
+    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+                         quantizedType);
+  if (isa<DequantizeCastOp>(op))
+    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+                           quantizedType);
+  llvm_unreachable("unexpected quant op");
+}
 
-//===----------------------------------------------------------------------===//
-// DequantizeCastOp
-//===----------------------------------------------------------------------===//
+// Convert a scalar or ranked tensor input using the single scale and zero
+// point of a per-layer quantized type. Both are materialized as
+// 'arith.constant' ops (scale in the expressed type, zero point in the storage
+// type), and the conversion itself is delegated to 'convertRanked'.
+Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
+                            Value input, UniformQuantizedType quantizedType) {
+
+  // Create scale and zero point constants
+  auto expressedType = quantizedType.getExpressedType();
+  auto storageType = quantizedType.getStorageType();
+  auto scaleAttr =
+      builder.getFloatAttr(expressedType, quantizedType.getScale());
+  auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
+  auto zeroPointAttr =
+      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
+  auto zeroPoint =
+      builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
+
+  auto inputShape = getScalarOrTensorShape(builder, loc, input);
+  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
+                       quantizedType);
+}
 
-class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
-  
-  Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
-                              UniformQuantizedType quantizedType) const {
-
-    // Create scale and zero point constants
-    auto expressedType = quantizedType.getExpressedType();
-    auto storageType = quantizedType.getStorageType();
-    auto scaleAttr =
-        builder.getFloatAttr(expressedType, quantizedType.getScale());
-    auto scale =
-        builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
-    auto zeroPointAttr =
-        builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
-    auto zeroPoint =
-        builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
-
-    auto inputShape = getScalarOrTensorShape(builder, loc, input);
-    return dequantizeScalarOrTensor(builder, loc, input, inputShape, scale,
+// Convert a scalar, ranked tensor, or unranked tensor input using per-layer
+// quantization. An unranked input is first turned into a ranked tensor with
+// 'flattenUnrankedTensor', converted, and then given back its original shape
+// with 'restoreUnrankedTensorShape'.
+Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
+                      Value input, UniformQuantizedType quantizedType) {
+  // Flatten input if unranked
+  bool isUnranked = isa<UnrankedTensorType>(input.getType());
+  Value inputShape;
+  if (isUnranked)
+    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
+
+  // Convert the (now ranked) tensor or scalar
+  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
+
+  // Restore original shape if unranked
+  if (isUnranked)
+    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+  return result;
+}
+
+// Convert a ranked tensor input whose quantization parameters vary along
+// 'channelAxis'. A 'linalg.generic' op visits every element of 'input' and
+// applies 'convertRanked' to it, reading the scale and zero point of the
+// element's channel through a single-dimension indexing map on 'channelAxis'.
+//
+// The returned tensor has the same shape as 'input'. Its element type is the
+// storage type when 'op' is a 'quant.qcast' and the expressed type when 'op'
+// is a 'quant.dcast'.
+Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
+                              Value input,
+                              UniformQuantizedPerAxisType quantizedType,
+                              int64_t channelAxis) {
+  auto *context = builder.getContext();
+
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto inputRank = inputType.getRank();
+
+  auto scales = materializePerChannelScales(builder, loc, quantizedType);
+  auto zeroPoints =
+      materializePerChannelZeroPoints(builder, loc, quantizedType);
+
+  // The init tensor must match what the region yields: stored values for
+  // 'quant.qcast', expressed values for 'quant.dcast'. Always using the
+  // storage type would make the 'linalg.generic' op ill-typed when
+  // dequantizing.
+  Type resultElementType = isa<QuantizeCastOp>(op)
+                               ? quantizedType.getStorageType()
+                               : quantizedType.getExpressedType();
+  auto initShape = tensor::getMixedSizes(builder, loc, input);
+  Value init =
+      builder.create<tensor::EmptyOp>(loc, initShape, resultElementType);
+
+  SmallVector<utils::IteratorType> iteratorTypes(
+      inputRank, utils::IteratorType::parallel);
+  // Scales and zero points are indexed only by the channel dimension.
+  auto channelAxisAffineMap = AffineMap::get(
+      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+  SmallVector<AffineMap> indexingMaps{
+    builder.getMultiDimIdentityMap(inputRank),
+    channelAxisAffineMap,
+    channelAxisAffineMap,
+    builder.getMultiDimIdentityMap(inputRank)
+  };
+  auto convertedValue = builder.create<linalg::GenericOp>(
+      loc,
+      init.getType(),  // resultType
+      ValueRange{input, scales, zeroPoints},  // inputs
+      ValueRange{init},  // outputs
+      indexingMaps,
+      iteratorTypes,
+      [&](OpBuilder& builder, Location loc, ValueRange args) {
+        // Block arguments: input element, scale, zero point, and the output
+        // element (not read).
+        assert(args.size() == 4);
+        auto inputValue = args[0];
+        auto scale = args[1];
+        auto zeroPoint = args[2];
+
+        auto result = convertRanked(builder, loc, op, inputValue, {}, scale,
                                     zeroPoint, quantizedType);
+
+        builder.create<linalg::YieldOp>(loc, result);
+      })
+      .getResult(0);
+
+  return convertedValue;
+}
+
+Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
+                        Value input,
+                        UniformQuantizedPerAxisType quantizedType) {
+  // Flatten unranked tensor into a 3D ranked tensor if necessary
+  bool isUnranked = isa<UnrankedTensorType>(input.getType());
+  int64_t channelAxis = quantizedType.getQuantizedDimension();
+  int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+  Value inputShape;
+  if (isUnranked) {
+    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+        builder, loc, input, channelAxis, channelAxisSize);
+    channelAxis = 1;
   }
 
-  Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
-                        UniformQuantizedType quantizedType) const {
-    // Flatten input if unranked
-    bool isUnranked = isa<UnrankedTensorType>(input.getType());
-    Value inputShape;
-    if (isUnranked)
-      std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
+  // Work on a ranked tensor
+  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
+                                        channelAxis);
 
-    // Process ranked tensor
-    auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
+  // Restore original tensor shape if unranked
+  if (isUnranked)
+    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
 
-    // Restore original shape if unranked
-    if (isUnranked)
-      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+  return result;
+}
 
-    return result;
-  }
+Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
+                       Value input, Type quantizedType) {
+  if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
+    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
 
-  Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
-                                UniformQuantizedPerAxisType quantizedType,
-                                int64_t channelAxis) const {
-    auto *context = builder.getContext();
-
-    auto inputType = cast<RankedTensorType>(input.getType());
-    auto inputRank = inputType.getRank();
-
-    auto scales = materializePerChannelScales(builder, loc, quantizedType);
-    auto zeroPoints =
-        materializePerChannelZeroPoints(builder, loc, quantizedType);
-
-    auto storageType = quantizedType.getStorageType();
-    auto initShape = tensor::getMixedSizes(builder, loc, input);
-    Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
-
-    SmallVector<utils::IteratorType> iteratorTypes(
-        inputRank, utils::IteratorType::parallel);
-    auto channelAxisAffineMap = AffineMap::get(
-        inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
-    SmallVector<AffineMap> indexingMaps{
-      builder.getMultiDimIdentityMap(inputRank),
-      channelAxisAffineMap,
-      channelAxisAffineMap,
-      builder.getMultiDimIdentityMap(inputRank)
-    };
-    auto storedValue = builder.create<linalg::GenericOp>(
-        loc,
-        init.getType(),  // resultType
-        ValueRange{input, scales, zeroPoints},  // inputs
-        ValueRange{init},  // outputs
-        indexingMaps,
-        iteratorTypes,
-        [&](OpBuilder& builder, Location loc, ValueRange args) {
-          assert(args.size() == 4);
-          auto expressedValue = args[0];
-          auto scale = args[1];
-          auto zeroPoint = args[2];
-
-          auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
-                                               scale, zeroPoint, quantizedType);
-
-          builder.create<linalg::YieldOp>(loc, result);
-        })
-        .getResult(0);
-
-    return storedValue;
-  }
+  if (auto uniformQuantizedPerAxisType =
+          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
+    return convertPerChannel(builder, loc, op, input,
+                             uniformQuantizedPerAxisType);
 
-  Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
-                          UniformQuantizedPerAxisType quantizedType) const {
-    // Flatten unranked tensor into a 3D ranked tensor if necessary
-    bool isUnranked = isa<UnrankedTensorType>(input.getType());
-    int64_t channelAxis = quantizedType.getQuantizedDimension();
-    int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
-    Value inputShape;
-    if (isUnranked) {
-      std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
-          builder, loc, input, channelAxis, channelAxisSize);
-      channelAxis = 1;
-    }
-
-    // Work on a ranked tensor
-    auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
-                                          channelAxis);
-
-    // Restore original tensor shape if unranked
-    if (isUnranked)
-      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
-
-    return result;
-  }
+  llvm_unreachable("unexpected quantized type");
+}
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
 
-public:
+struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
   using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
 
   LogicalResult
@@ -434,19 +477,7 @@ class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeC
     input = rewriter.create<quant::StorageCastOp>(
         loc, storageScalarOrTensorType, input);
 
-    // Flatten unranked tensor input
-    Value result;
-    if (auto uniformQuantizedType =
-            dyn_cast<UniformQuantizedType>(quantizedType)) {
-      result = convertPerLayer(rewriter, loc, input, uniformQuantizedType);
-    } else if (auto uniformQuantizedPerAxisType =
-                   dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
-      result =
-          convertPerChannel(rewriter, loc, input, uniformQuantizedPerAxisType);
-    } else {
-      llvm_unreachable("unexpected quantized type");
-    }
-
+    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -457,120 +488,7 @@ class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeC
 // QuantizeCastOp
 //===----------------------------------------------------------------------===//
 
-class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
-
-  Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
-                              UniformQuantizedType quantizedType) const {
-
-    // Create scale and zero point constants
-    auto expressedType = quantizedType.getExpressedType();
-    auto storageType = quantizedType.getStorageType();
-    auto scaleAttr =
-        builder.getFloatAttr(expressedType, quantizedType.getScale());
-    auto scale =
-        builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
-    auto zeroPointAttr =
-        builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
-    auto zeroPoint =
-        builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
-
-    auto inputShape = getScalarOrTensorShape(builder, loc, input);
-    return quantizeScalarOrTensor(builder, loc, input, inputShape, scale,
-                                  zeroPoint, quantizedType);
-  }
-
-  Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
-                        UniformQuantizedType quantizedType) const {
-    // Flatten input if unranked
-    bool isUnranked = isa<UnrankedTensorType>(input.getType());
-    Value inputShape;
-    if (isUnranked)
-      std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
-
-    // Process ranked tensor
-    auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
-
-    // Restore original shape if unranked
-    if (isUnranked)
-      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
-
-    return result;
-  }
-
-  Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
-                                UniformQuantizedPerAxisType quantizedType,
-                                int64_t channelAxis) const {
-    auto *context = builder.getContext();
-
-    auto inputType = cast<RankedTensorType>(input.getType());
-    auto inputRank = inputType.getRank();
-
-    auto scales = materializePerChannelScales(builder, loc, quantizedType);
-    auto zeroPoints =
-        materializePerChannelZeroPoints(builder, loc, quantizedType);
-
-    auto storageType = quantizedType.getStorageType();
-    auto initShape = tensor::getMixedSizes(builder, loc, input);
-    Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
-
-    SmallVector<utils::IteratorType> iteratorTypes(
-        inputRank, utils::IteratorType::parallel);
-    auto channelAxisAffineMap = AffineMap::get(
-        inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
-    SmallVector<AffineMap> indexingMaps{
-      builder.getMultiDimIdentityMap(inputRank),
-      channelAxisAffineMap,
-      channelAxisAffineMap,
-      builder.getMultiDimIdentityMap(inputRank)
-    };
-    auto storedValue = builder.create<linalg::GenericOp>(
-        loc,
-        init.getType(),  // resultType
-        ValueRange{input, scales, zeroPoints},  // inputs
-        ValueRange{init},  // outputs
-        indexingMaps,
-        iteratorTypes,
-        [&](OpBuilder& builder, Location loc, ValueRange args) {
-          assert(args.size() == 4);
-          auto expressedValue = args[0];
-          auto scale = args[1];
-          auto zeroPoint = args[2];
-
-          auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
-                                               scale, zeroPoint, quantizedType);
-
-          builder.create<linalg::YieldOp>(loc, result);
-        })
-        .getResult(0);
-
-    return storedValue;
-  }
-
-  Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
-                          UniformQuantizedPerAxisType quantizedType) const {
-    // Flatten unranked tensor into a 3D ranked tensor if necessary
-    bool isUnranked = isa<UnrankedTensorType>(input.getType());
-    int64_t channelAxis = quantizedType.getQuantizedDimension();
-    int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
-    Value inputShape;
-    if (isUnranked) {
-      std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
-          builder, loc, input, channelAxis, channelAxisSize);
-      channelAxis = 1;
-    }
-
-    // Work on a ranked tensor
-    auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
-                                          channelAxis);
-
-    // Restore original tensor shape if unranked
-    if (isUnranked)
-      result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
-
-    return result;
-  }
-
-public:
+struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
   using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
 
   LogicalResult
@@ -578,18 +496,10 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto input = op.getInput();
-    auto resultScalarType = getScalarType(op.getResult().getType());
+    auto quantizedType = getScalarType(op.getResult().getType());
 
     // Flatten unranked tensor input
-    Value result;
-    if (auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType)) {
-      result = convertPerLayer(rewriter, loc, input, quantizedType);
-    } else if (auto quantizedType =
-                   dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
-      result = convertPerChannel(rewriter, loc, input, quantizedType);
-    } else {
-      llvm_unreachable("unexpected quantized type");
-    }
+    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
 
     // Cast stored value to result quantized value
     rewriter.replaceOpWithNewOp<quant::StorageCastOp>(

>From 6eabe11b9880d2e22b2d4077d4f2639845219cce Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 17:29:17 -0400
Subject: [PATCH 13/18] Unit test for 'quant.dcast' lowering with bug fixes

---
 .../Quant/Transforms/LowerQuantOps.cpp        | 172 ++++++++++--
 mlir/test/Dialect/Quant/lower-quant-ops.mlir  | 255 ++++++++++++++----
 2 files changed, 360 insertions(+), 67 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 77be3dd11d3ad..4adeb9218ff8e 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -76,6 +76,19 @@ Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
   return tensorConstant;
 }
 
+// Reshape an unranked tensor into a 1D ranked tensor.
+//
+// - input
+//   Unranked tensor.
+//
+// Return values:
+//
+// - flatInput
+//   1D ranked, dynamically shaped tensor.
+//
+// - inputShape
+//   1D extent tensor containing the shape of the original unranked input.
+//
 std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                               Value input) {
   // Get unranked input shape and total size
@@ -100,6 +113,28 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
   return std::make_pair(flatInput, inputShape);
 }
 
+// Reshape an unranked tensor into a 3D ranked tensor where the central
+// dimension of the result tensor corresponds to dimension 'axis' of the input
+// tensor.
+//
+// - input
+//   Unranked tensor.
+//
+// - axis
+//   Index of the input dimension around which other input dimensions will be
+//   collapsed.
+//
+// - axisSize
+//   Size of input dimension 'axis'.
+//
+// Return values:
+//
+// - flatInput
+//   3D ranked tensor of shape [?, axisSize, ?].
+//
+// - inputShape
+//   1D extent tensor containing the shape of the original unranked input.
+//
 std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
                                                         Location loc,
                                                         Value input,
@@ -142,6 +177,14 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
   return std::make_pair(flatInput, inputShape);
 }
 
+// Reshape an input tensor into its original unranked shape.
+//
+// - input
+//   Ranked tensor.
+//
+// - inputShape
+//   1D extent tensor.
+//
 Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                  Value inputShape) {
   auto inputType = cast<RankedTensorType>(input.getType());
@@ -150,6 +193,15 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
   return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
 }
 
+// Create a tensor constant containing all scales in a per-channel quantized
+// type. Example:
+//
+//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
+//
 Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                   UniformQuantizedPerAxisType quantizedType) {
   auto scales = quantizedType.getScales();
@@ -162,6 +214,15 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
   return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
 }
 
+// Create a tensor constant containing all zero points in a per-channel
+// quantized type. Example:
+//
+//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
+//
 Value materializePerChannelZeroPoints(
     OpBuilder &builder, Location loc,
     UniformQuantizedPerAxisType quantizedType) {
@@ -178,6 +239,19 @@ Value materializePerChannelZeroPoints(
   return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
 }
 
+// Clamp the given scalar or tensor input using the storage bounds encoded in
+// the given quantized type, if present.
+//
+// - input
+//   Scalar or ranked tensor input. The element type must match the storage type
+//   of 'quantizedType'.
+//
+// - inputShape
+//   If 'input' is a tensor, combination of attributes/values representing its
+//   static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - quantizedType
+//   Per-layer or per-channel quantized type.
 Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                           ArrayRef<OpFoldResult> inputShape,
                           QuantizedType quantizedType) {
@@ -209,6 +283,7 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
   return input;
 }
 
+// Emit op 'arith.fptosi' or 'arith.fptoui'.
 Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                             Type resultType, bool isSigned) {
   if (isSigned)
@@ -216,6 +291,7 @@ Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
   return builder.create<arith::FPToUIOp>(loc, resultType, input);
 }
 
+// Emit op 'arith.sitofp' or 'arith.uitofp'.
 Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                             Type resultType, bool isSigned) {
   if (isSigned)
@@ -225,6 +301,8 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
 
 // Quantize a scalar or ranked tensor value. The stored value is clamped using 
 // the storage bounds encoded in the given quantized type.
+//
+// See function 'convertRanked()' below for a description of the arguments.
 Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                     ArrayRef<OpFoldResult> inputShape, Value scale,
                     Value zeroPoint, QuantizedType quantizedType) {
@@ -266,7 +344,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
   return storedValueClamped;
 }
 
-// Dequantize a scalar or ranked tensor value.
+// Dequantize a scalar or ranked tensor input.
+//
+// See function 'convertRanked()' below for a description of the arguments.
 Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                       ArrayRef<OpFoldResult> inputShape, Value scale,
                       Value zeroPoint, QuantizedType quantizedType) {
@@ -310,10 +390,10 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
 //   static/dynamic dimensions. If 'input' is a scalar, empty list.
 //
 // - scale
-//   Scale as a scalar value.
+//   Scale as a floating-point scalar value.
 //
 // - zeroPoint
-//   Zero point as a scalar value.
+//   Zero point as an integer scalar value.
 //
 // - quantizedType
 //   Scalar quantized type of the result ('quant.qcast') or of the input
@@ -331,9 +411,20 @@ Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
   llvm_unreachable("unexpected quant op");
 }
 
+// Convert an operation using per-layer quantization with a scalar or ranked
+// tensor input.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar or ranked tensor.
+//
+// - quantizedType
+//   Per-layer quantized type.
+//
 Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                             Value input, UniformQuantizedType quantizedType) {
-
   // Create scale and zero point constants
   auto expressedType = quantizedType.getExpressedType();
   auto storageType = quantizedType.getStorageType();
@@ -350,6 +441,17 @@ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                        quantizedType);
 }
 
+// Convert an operation using per-layer quantization.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar, ranked tensor, or unranked tensor.
+//
+// - quantizedType
+//   Per-layer quantized type.
+//
 Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                       Value input, UniformQuantizedType quantizedType) {
   // Flatten input if unranked
@@ -368,6 +470,18 @@ Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
   return result;
 }
 
+// Convert an operation using per-channel quantization with a ranked tensor
+// input.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Ranked tensor.
+//
+// - quantizedType
+//   Per-channel quantized type.
+//
 Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                               Value input,
                               UniformQuantizedPerAxisType quantizedType,
@@ -381,9 +495,11 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
   auto zeroPoints =
       materializePerChannelZeroPoints(builder, loc, quantizedType);
 
-  auto storageType = quantizedType.getStorageType();
+  auto elementType = isa<FloatType>(inputType.getElementType())
+                         ? quantizedType.getStorageType()
+                         : quantizedType.getExpressedType();
   auto initShape = tensor::getMixedSizes(builder, loc, input);
-  Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
+  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
 
   SmallVector<utils::IteratorType> iteratorTypes(
       inputRank, utils::IteratorType::parallel);
@@ -395,7 +511,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
     channelAxisAffineMap,
     builder.getMultiDimIdentityMap(inputRank)
   };
-  auto storedValue = builder.create<linalg::GenericOp>(
+  auto result = builder.create<linalg::GenericOp>(
       loc,
       init.getType(),  // resultType
       ValueRange{input, scales, zeroPoints},  // inputs
@@ -404,20 +520,31 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
       iteratorTypes,
       [&](OpBuilder& builder, Location loc, ValueRange args) {
         assert(args.size() == 4);
-        auto expressedValue = args[0];
+        auto input = args[0];
         auto scale = args[1];
         auto zeroPoint = args[2];
 
-        auto result = convertRanked(builder, loc, op, expressedValue, {}, scale,
+        auto result = convertRanked(builder, loc, op, input, {}, scale,
                                     zeroPoint, quantizedType);
 
         builder.create<linalg::YieldOp>(loc, result);
       })
       .getResult(0);
 
-  return storedValue;
+  return result;
 }
 
+// Convert an operation using per-channel quantization.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar, ranked tensor, or unranked tensor.
+//
+// - quantizedType
+//   Per-channel quantized type.
+//
 Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                         Value input,
                         UniformQuantizedPerAxisType quantizedType) {
@@ -443,6 +570,19 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
   return result;
 }
 
+// Convert a quantization operation.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar, ranked tensor, or unranked tensor. The element type matches
+//   the storage type (quant.dcast) or expressed type (quant.qcast) of
+//   'quantizedType'.
+//
+// - quantizedType
+//   Per-layer or per-channel quantized type.
+//
 Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                        Value input, Type quantizedType) {
   if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
@@ -456,10 +596,7 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
   llvm_unreachable("unexpected quantized type");
 }
 
-//===----------------------------------------------------------------------===//
-// DequantizeCastOp
-//===----------------------------------------------------------------------===//
-
+// Lowering pattern for 'quant.dcast'
 struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
   using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
 
@@ -478,16 +615,13 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
         loc, storageScalarOrTensorType, input);
 
     auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
+
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-
-//===----------------------------------------------------------------------===//
-// QuantizeCastOp
-//===----------------------------------------------------------------------===//
-
+// Lowering pattern for 'quant.qcast'
 struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
   using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
 
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 1030d5b20620b..6bba9f5c03772 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -1,5 +1,209 @@
 // RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s
 
+// CHECK-LABEL: @dcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<u8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<i8>
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<i8> to tensor<f32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<f32>
+// CHECK: return %[[EXPRESSED]] : tensor<f32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_0d(%arg0: tensor<!qalias>) -> tensor<f32> {
+  %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<3x?x5xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_INT]], %[[C_1]] : tensor<3x?x5xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+// CHECK: return %[[EXPRESSED]] : tensor<3x?x5xf32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_ranked(%arg0: tensor<3x?x5x!qalias>) -> tensor<3x?x5xf32> {
+  %0 = quant.dcast %arg0 : tensor<3x?x5x!qalias> to tensor<3x?x5xf32>
+  return %0 : tensor<3x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<*xi8>
+// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[STORED_INT]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_INT]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<1xindex>) -> tensor<?xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_COLLAPSED]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<?xi8> to tensor<?xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[EXPRESSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+  %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>> to tensor<4x?x?x5xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_1]] : tensor<4x?x?x5xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_2]] : tensor<4x?x?x5xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[STORED_TENSOR]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xi8>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK:   linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<4x?x?x5xf32>
+// CHECK: return %[[GENERIC]] : tensor<4x?x?x5xf32>
+
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @dcast_per_channel_ranked(%arg0: tensor<4x?x?x5x!qalias>) -> tensor<4x?x?x5xf32> {
+  %0 = quant.dcast %arg0 : tensor<4x?x?x5x!qalias> to tensor<4x?x?x5xf32>
+  return %0 : tensor<4x?x?x5xf32>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>> to tensor<*xi8>
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[STORED_TENSOR]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+
+// CHECK: %[[NUM_CHANNELS:.*]] = arith.constant 3 : index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[NUM_CHANNELS]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_TENSOR]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<3xindex>) -> tensor<?x3x?xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?x3x?xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_2]] : tensor<?x3x?xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[STORED_COLLAPSED]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xi8>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK:   linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<?x3x?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @dcast_per_channel_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+  %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @qcast_per_layer_scalar
 // CHECK-SAME: %[[ARG_0:.*]]: f32
 
@@ -219,7 +423,7 @@ func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
 
 !qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
 func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> {
-  %0 = "quant.qcast"(%arg0) : (tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias>
+  %0 = quant.qcast %arg0 : tensor<4x?x?x5xf32> to tensor<4x?x?x5x!qalias>
   return %0 : tensor<4x?x?x5x!qalias>
 }
 
@@ -253,7 +457,7 @@ func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x
 
 !qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
 func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
-  %0 = "quant.qcast"(%arg0) : (tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias>
+  %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias>
   return %0 : tensor<4x2x5x!qalias>
 }
 
@@ -301,52 +505,7 @@ func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4
 
 !qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
 func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
-  %0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias>
+  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
   return %0 : tensor<*x!qalias>
 }
 
-// -----
-
-// CHECK-LABEL: @dcast_per_layer_scalar
-// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
-
-// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
-
-// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
-// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
-// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
-
-// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
-// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
-// CHECK: return %[[EXPRESSED]] : f32
-
-!qalias = !quant.uniform<i8:f32, 2.0:10>
-func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
-  %0 = quant.dcast %arg0 : !qalias to f32
-  return %0 : f32
-}
-
-// -----
-
-// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
-// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
-
-// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
-
-// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
-
-// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
-// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
-
-// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
-// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
-// CHECK: return %[[EXPRESSED]] : f32
-
-!qalias = !quant.uniform<u8:f32, 2.0:10>
-func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
-  %0 = quant.dcast %arg0 : !qalias to f32
-  return %0 : f32
-}
-

>From 647b8ec64d01af7606c29a0f90ad98b49ba8fa95 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 23 Jul 2024 15:48:56 -0400
Subject: [PATCH 14/18] Verifiers for 'quant.dcast' and 'quant.qcast'

---
 mlir/include/mlir/Dialect/Quant/IR/Quant.h    |   2 +
 .../mlir/Dialect/Quant/IR/QuantBase.td        |  96 ++++++-----
 .../include/mlir/Dialect/Quant/IR/QuantOps.td |  40 ++++-
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        | 105 +++++++++++-
 mlir/test/Dialect/Quant/invalid.mlir          | 161 ++++++++++++++++++
 mlir/test/Dialect/Quant/ops.mlir              |  96 +++++++++++
 6 files changed, 443 insertions(+), 57 deletions(-)
 create mode 100644 mlir/test/Dialect/Quant/invalid.mlir
 create mode 100644 mlir/test/Dialect/Quant/ops.mlir

diff --git a/mlir/include/mlir/Dialect/Quant/IR/Quant.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
index c5ca88ec69795..11a969a3ee519 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/Quant.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -24,7 +24,9 @@
 namespace mlir {
 namespace quant {
 
+class QuantizedType;
 class UniformQuantizedType;
+class UniformQuantizedPerAxisType;
 
 } // namespace quant
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index e465d855c1986..8ef89a0d1393b 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -22,53 +22,63 @@ def Quant_Dialect : Dialect {
   let useDefaultTypePrinterParser = 1;
 }
 
+
 //===----------------------------------------------------------------------===//
-// Quantization type definitions
+// Type definitions
 //===----------------------------------------------------------------------===//
 
-class quant_TypedPrimitiveOrContainer<Type etype> :
-    Type<Or<[etype.predicate,
-                TensorOf<[etype]>.predicate,
-                VectorOf<[etype]>.predicate]>,
-         "primitive/tensor/vector of " # etype.summary>;
+class quant_ScalarOrTensorOf<Type etype> :
+    Type<Or<[etype.predicate, TensorOf<[etype]>.predicate]>,
+         "scalar or tensor of " # etype.summary>;
 
-// An implementation of QuantizedType.
 def quant_QuantizedType :
-    Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "QuantizedType">;
-
-// A primitive type that can represent a real value. This is either a
-// floating point value or a quantized type.
-def quant_RealPrimitiveType :
-    Type<Or<[AnyFloat.predicate, quant_QuantizedType.predicate]>,
-    "real valued primitive (float or quantized type)">;
-
-// A primitive type that can represent a storage value. This is either an
-// integer or quantized type.
-def quant_StoragePrimitiveType :
-    Type<Or<[AnySignlessInteger.predicate, quant_QuantizedType.predicate]>,
-    "quantized storage primitive (integer or quantized type)">;
-
-// A primitive or container of RealPrimitiveType.
-def quant_RealValueType :
-    quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
-
-// A primitive or container of StoragePrimitiveType.
-def quant_StorageValueType :
-    quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
-
-// Either a real valued or storage primitive or container type.
-def quant_RealOrStorageValueType :
-    Type<Or<[quant_RealValueType.predicate, quant_StorageValueType.predicate]>,
-    "real valued or storage primitive or container type">;
-
-// An implementation of UniformQuantizedType.
-def quant_UniformQuantizedType :
-    DialectType<Quant_Dialect,
-                CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
-                "UniformQuantizedType">;
-
-// Predicate for detecting a container or primitive of UniformQuantizedType.
-def quant_UniformQuantizedValueType :
-    quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
+    Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "quantized type">;
+
+def quant_ScalarType :
+    Type<Or<[
+      AnyInteger.predicate,
+      AnyFloat.predicate,
+      quant_QuantizedType.predicate
+    ]>, "integer, float, or quantized scalar">;
+
+def quant_IntegerOrQuantizedType :
+    Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>>;
+
+def quant_FloatScalarOrTensor :
+    quant_ScalarOrTensorOf<AnyFloat>;
+
+def quant_IntegerScalarOrTensor :
+    quant_ScalarOrTensorOf<AnyInteger>;
+
+def quant_QuantizedScalarOrTensor :
+    quant_ScalarOrTensorOf<quant_QuantizedType>;
+
+def quant_IntegerOrQuantizedScalarOrTensor :
+    quant_ScalarOrTensorOf<quant_IntegerOrQuantizedType>;
+
+
+//===----------------------------------------------------------------------===//
+// Traits
+//===----------------------------------------------------------------------===//
+
+def quant_SameScalarOrTensorShape :
+    PredOpTrait<
+      "input and result are both scalars or both tensors with matching shape",
+      Or<[
+        And<[
+          TypeIsPred<"input", quant_ScalarType>,
+          TypeIsPred<"result", quant_ScalarType>
+        ]>,
+        And<[
+          TypeIsPred<"input", AnyUnrankedTensor>,
+          TypeIsPred<"result", AnyUnrankedTensor>
+        ]>,
+        And<[
+          TypeIsPred<"input", AnyRankedTensor>,
+          TypeIsPred<"result", AnyRankedTensor>,
+          AllShapesMatch<["input", "result"]>.predicate
+        ]>
+      ]>
+    >;
 
 #endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 7a6d270dbb6e9..3a24c8148d1f5 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -28,7 +28,9 @@ class quant_Op<string mnemonic, list<Trait> traits> :
 // Quantization casts
 //===----------------------------------------------------------------------===//
 
-def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
+def quant_DequantizeCastOp : quant_Op<"dcast", [
+    Pure,
+    quant_SameScalarOrTensorShape]> {
   let summary = "convert back from a quantized to quantizable (expressed) type operation";
   let description = [{
     A DequantizeCast op `dcast` represents the inverse of a `qcast`,
@@ -42,12 +44,22 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
     all operands to ops that must operate with the expressed type (typically
     math ops prior to lowering to target-specific, quantized kernels).
   }];
-  let arguments = (ins quant_RealValueType:$input);
-  let results = (outs quant_RealValueType:$result);
+  let arguments = (ins quant_QuantizedScalarOrTensor:$input);
+  let results = (outs quant_FloatScalarOrTensor:$result);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Return the float type of the scalar or tensor result.
+    FloatType getFloatType();
+    
+    /// Return the quantized type of the scalar or tensor input.
+    quant::QuantizedType getQuantizedType();
+  }];
 }
 
-def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
+def quant_QuantizeCastOp : quant_Op<"qcast", [
+    Pure,
+    quant_SameScalarOrTensorShape]> {
   let summary = "convert a quantizable type to a quantized type";
   let description = [{
     A QuantizeCast `qcast` represents a potential type shift from a quantizable
@@ -71,12 +83,22 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
     it is legal to use a quantized representation (but is not known to be
     acceptable).
   }];
-  let arguments = (ins quant_RealValueType:$input);
-  let results = (outs quant_RealValueType:$result);
+  let arguments = (ins quant_FloatScalarOrTensor:$input);
+  let results = (outs quant_QuantizedScalarOrTensor:$result);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Return the float type of the scalar or tensor input.
+    FloatType getFloatType();
+    
+    /// Return the quantized type of the scalar or tensor result.
+    quant::QuantizedType getQuantizedType();
+  }];
 }
 
-def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
+def quant_StorageCastOp : quant_Op<"scast", [
+    Pure,
+    quant_SameScalarOrTensorShape]> {
   let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
   let description = [{
     A StorageCast `scast` represents a cast from or to a type based on the
@@ -97,8 +119,8 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
     vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
     ```
   }];
-  let arguments = (ins quant_RealOrStorageValueType:$input);
-  let results = (outs quant_RealOrStorageValueType:$result);
+  let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
+  let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index e04ca7eb7e715..ef66ccf3d9e01 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/MathExtras.h"
@@ -28,13 +29,70 @@ namespace quant {
 
 namespace {
 
-Type getPrimitiveType(Type ty) {
-  if (auto tensorType = dyn_cast<TensorType>(ty))
-    return tensorType.getElementType();
-  return ty;
+// Verify the integrity of per-axis quantization information, if present.
+//
+// - quantizedType
+//   Any quantized type. Any quantized type with no per-axis quantization is
+//   ignored.
+//
+// - containerType
+//   Original input or result type of the operation using the provided quantized
+//   type. Used to ensure that the quantized type appears within a tensor and
+//   that the tensor is compatible with per-axis quantization information.
+//
+LogicalResult verifyPerAxisQuantization(Operation *op,
+                                        QuantizedType quantizedType,
+                                        Type containerType) {
+  auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+  if (!quantizedPerAxisType)
+    return success();
+
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return op->emitError("scalar types may not use per-axis quantization");
+
+  if (!tensorType.hasRank())
+    return success();
+
+  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+  if (quantizedDimension >= tensorType.getRank())
+    return op->emitError("quantized dimension must be less than tensor rank");
+
+  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
+  if (quantizedDimensionSize != ShapedType::kDynamic &&
+      quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+    return op->emitError(
+        "quantized dimension size does not match number of scales");
+
+  return success();
+}
+
+// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
+//
+// - quantizedType
+//   Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
+//   whether as a primitive type or in a tensor.
+//
+// - floatType
+//   Float type used in the input ('quant.qcast') or result ('quant.dcast'),
+//   whether as a primitive type or in a tensor.
+//
+// - containerType
+//   Type of original input or result.
+//
+LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
+                                   FloatType floatType, Type containerType) {
+  if (quantizedType.getExpressedType() != floatType)
+    return op->emitError(
+        "expressed type in quantized type expected to match float type");
+
+  if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
+    return failure();
+
+  return success();
 }
 
-} // namespace
+}  // namespace
 
 
 //===----------------------------------------------------------------------===//
@@ -52,6 +110,24 @@ void QuantDialect::initialize() {
 }
 
 
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DequantizeCastOp::verify() {
+  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+                              getInput().getType());
+}
+
+FloatType DequantizeCastOp::getFloatType() {
+  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+QuantizedType DequantizeCastOp::getQuantizedType() {
+  return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+
 //===----------------------------------------------------------------------===//
 // StorageCastOp
 //===----------------------------------------------------------------------===//
@@ -65,6 +141,25 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
   return srcScastOp.getInput();
 }
 
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult QuantizeCastOp::verify() {
+  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+                              getInput().getType());
+}
+
+FloatType QuantizeCastOp::getFloatType() {
+  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+QuantizedType QuantizeCastOp::getQuantizedType() {
+  return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+
 } // namespace quant
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
new file mode 100644
index 0000000000000..9976ce5d00d65
--- /dev/null
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+func.func @dcast_invalid_input(%arg0: f32) {
+  // expected-error at +1 {{operand #0 must be scalar or tensor of quantized type}}
+  %0 = quant.dcast %arg0 : f32 to f32
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_invalid_result(%arg0: !qalias) {
+  // expected-error at +1 {{result #0 must be scalar or tensor of floating-point}}
+  %0 = quant.dcast %arg0 : !qalias to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_scalar_tensor(%arg0: !qalias) {
+  // expected-error at +1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.dcast %arg0 : !qalias to tensor<f32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_ranked_unranked_tensor(%arg0: tensor<!qalias>) {
+  // expected-error at +1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<*xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3x!qalias>) {
+  // expected-error at +1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<?x3xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_float_type_mismatch(%arg0: !qalias) {
+  // expected-error at +1 {{expressed type in quantized type expected to match float type}}
+  %0 = quant.dcast %arg0 : !qalias to f64
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_scalar(%arg0: !qalias) {
+  // expected-error at +1 {{scalar types may not use per-axis quantization}}
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x!qalias>) {
+  // expected-error at +1 {{quantized dimension must be less than tensor rank}}
+  %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_invalid_num_scales(%arg0: tensor<2x3x4x!qalias>) {
+  // expected-error at +1 {{quantized dimension size does not match number of scales}}
+  %0 = quant.dcast %arg0 : tensor<2x3x4x!qalias> to tensor<2x3x4xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_invalid_input(%arg0: !qalias) {
+  // expected-error at +1 {{operand #0 must be scalar or tensor of floating-point}}
+  %0 = quant.qcast %arg0 : !qalias to !qalias
+  return
+}
+
+// -----
+
+func.func @qcast_invalid_result(%arg0: f32) {
+  // expected-error at +1 {{result #0 must be scalar or tensor of quantized type}}
+  %0 = quant.qcast %arg0 : f32 to f32
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_scalar_tensor(%arg0: tensor<f32>) {
+  // expected-error at +1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.qcast %arg0 : tensor<f32> to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_ranked_unranked_tensor(%arg0: tensor<f32>) {
+  // expected-error at +1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.qcast %arg0 : tensor<f32> to tensor<*x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xf32>) {
+  // expected-error at +1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<?x3x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_float_type_mismatch(%arg0: f64) {
+  // expected-error at +1 {{expressed type in quantized type expected to match float type}}
+  %0 = quant.qcast %arg0 : f64 to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_scalar(%arg0: f32) {
+  // expected-error at +1 {{scalar types may not use per-axis quantization}}
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3xf32>) {
+  // expected-error at +1 {{quantized dimension must be less than tensor rank}}
+  %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_invalid_num_scales(%arg0: tensor<2x3x4xf32>) {
+  // expected-error at +1 {{quantized dimension size does not match number of scales}}
+  %0 = quant.qcast %arg0 : tensor<2x3x4xf32> to tensor<2x3x4x!qalias>
+  return
+}
+
+
diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
new file mode 100644
index 0000000000000..ab3d6decfb248
--- /dev/null
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_scalar(%arg0: !qalias) {
+  %0 = quant.dcast %arg0 : !qalias to f32
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_ranked(%arg0: tensor<2x?x4x!qalias>) {
+  %0 = quant.dcast %arg0 : tensor<2x?x4x!qalias> to tensor<2x?x4xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_unranked(%arg0: tensor<*x!qalias>) {
+  %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_static(%arg0: tensor<1x2x3x!qalias>) {
+  %0 = quant.dcast %arg0 : tensor<1x2x3x!qalias> to tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_dynamic(%arg0: tensor<?x?x?x!qalias>) {
+  %0 = quant.dcast %arg0 : tensor<?x?x?x!qalias> to tensor<?x?x?xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_unranked(%arg0: tensor<*x!qalias>) {
+  %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_scalar(%arg0: f32) {
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_ranked(%arg0: tensor<2x?x4xf32>) {
+  %0 = quant.qcast %arg0 : tensor<2x?x4xf32> to tensor<2x?x4x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_unranked(%arg0: tensor<*xf32>) {
+  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_static(%arg0: tensor<1x2x3xf32>) {
+  %0 = quant.qcast %arg0 : tensor<1x2x3xf32> to tensor<1x2x3x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_dynamic(%arg0: tensor<?x?x?xf32>) {
+  %0 = quant.qcast %arg0 : tensor<?x?x?xf32> to tensor<?x?x?x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) {
+  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+  return
+}
+

>From 09278de0b804824b0e6c1062e31d1a858549d23c Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 23 Jul 2024 17:56:36 -0400
Subject: [PATCH 15/18] Verifier for 'quant.scast' and unit tests

---
 .../mlir/Dialect/Quant/IR/QuantBase.td        | 28 +++++-
 .../include/mlir/Dialect/Quant/IR/QuantOps.td | 11 ++-
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        | 65 +++++++++----
 mlir/test/Dialect/Quant/invalid.mlir          | 97 +++++++++++++++++++
 mlir/test/Dialect/Quant/ops.mlir              | 55 +++++++++++
 5 files changed, 233 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 8ef89a0d1393b..d81838db3dc1a 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -36,19 +36,24 @@ def quant_QuantizedType :
 
 def quant_ScalarType :
     Type<Or<[
-      AnyInteger.predicate,
+      AnySignlessInteger.predicate,
       AnyFloat.predicate,
       quant_QuantizedType.predicate
-    ]>, "integer, float, or quantized scalar">;
+    ]>,
+    "signless integer, float, or quantized scalar">;
 
 def quant_IntegerOrQuantizedType :
-    Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>>;
+    Type<Or<[
+      AnySignlessInteger.predicate,
+      quant_QuantizedType.predicate
+    ]>,
+    "signless integer or quantized type">;
 
 def quant_FloatScalarOrTensor :
     quant_ScalarOrTensorOf<AnyFloat>;
 
 def quant_IntegerScalarOrTensor :
-    quant_ScalarOrTensorOf<AnyInteger>;
+    quant_ScalarOrTensorOf<AnySignlessInteger>;
 
 def quant_QuantizedScalarOrTensor :
     quant_ScalarOrTensorOf<quant_QuantizedType>;
@@ -81,4 +86,19 @@ def quant_SameScalarOrTensorShape :
       ]>
     >;
 
+def quant_IntegerAndQuantizedCombination :
+    PredOpTrait<
+      "input must be integer and result must be quantized, or vice versa",
+      Or<[
+        And<[
+          TypeIsPred<"input", quant_QuantizedScalarOrTensor>,
+          TypeIsPred<"result", quant_IntegerScalarOrTensor>
+        ]>,
+        And<[
+          TypeIsPred<"input", quant_IntegerScalarOrTensor>,
+          TypeIsPred<"result", quant_QuantizedScalarOrTensor>
+        ]>
+      ]>
+    >;
+
 #endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 3a24c8148d1f5..5dab02e8e1ee5 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -98,7 +98,8 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [
 
 def quant_StorageCastOp : quant_Op<"scast", [
     Pure,
-    quant_SameScalarOrTensorShape]> {
+    quant_SameScalarOrTensorShape,
+    quant_IntegerAndQuantizedCombination]> {
   let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
   let description = [{
     A StorageCast `scast` represents a cast from or to a type based on the
@@ -122,7 +123,15 @@ def quant_StorageCastOp : quant_Op<"scast", [
   let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
   let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
   let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// Return the integer type used either in the input or the result.
+    IntegerType getIntegerType();
+    
+    /// Return the quantized type used either in the input or the result.
+    quant::QuantizedType getQuantizedType();
+  }];
 }
 
 #endif // QUANT_OPS
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index ef66ccf3d9e01..f722eb8e30806 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -86,10 +86,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
     return op->emitError(
         "expressed type in quantized type expected to match float type");
 
-  if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
-    return failure();
-
-  return success();
+  // Verify integrity of per-axis quantization information, if present.
+  return verifyPerAxisQuantization(op, quantizedType, containerType);
 }
 
 }  // namespace
@@ -128,20 +126,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() {
 }
 
 
-//===----------------------------------------------------------------------===//
-// StorageCastOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
-  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
-  // value of x if the casts invert each other.
-  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
-  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
-    return OpFoldResult();
-  return srcScastOp.getInput();
-}
-
-
 //===----------------------------------------------------------------------===//
 // QuantizeCastOp
 //===----------------------------------------------------------------------===//
@@ -160,6 +144,51 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
 }
 
 
+//===----------------------------------------------------------------------===//
+// StorageCastOp
+//===----------------------------------------------------------------------===//
+
+IntegerType StorageCastOp::getIntegerType() {
+  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+  if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
+    return integerType;
+
+  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+  return cast<IntegerType>(resultScalarType);
+}
+
+QuantizedType StorageCastOp::getQuantizedType() {
+  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+  if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
+    return quantizedType;
+
+  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+  return cast<QuantizedType>(resultScalarType);
+}
+
+OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
+  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
+  // value of x if the casts invert each other.
+  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
+    return OpFoldResult();
+  return srcScastOp.getInput();
+}
+
+LogicalResult StorageCastOp::verify() {
+  auto quantizedType = getQuantizedType();
+  auto integerType = getIntegerType();
+  if (quantizedType.getStorageType() != integerType)
+    return emitError(
+        "storage type in quantized type expected to match integer type");
+
+  // Verify integrity of per-axis quantization information, if available. While
+  // the quantization type may appear in the input or the result, their tensor
+  // shapes are guaranteed to be identical at this point.
+  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
+}
+
+
 } // namespace quant
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
index 9976ce5d00d65..ba3a8e312d96e 100644
--- a/mlir/test/Dialect/Quant/invalid.mlir
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -158,4 +158,101 @@ func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3x4xf32>) {
   return
 }
 
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_input(%arg0: si32) {
+  // expected-error@+1 {{operand #0 must be scalar or tensor of signless integer or quantized type}}
+  %0 = quant.scast %arg0 : si32 to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_result(%arg0: !qalias) {
+  // expected-error@+1 {{result #0 must be scalar or tensor of signless integer or quantized type}}
+  %0 = quant.scast %arg0 : !qalias to si32
+  return
+}
+
+// -----
+
+func.func @scast_both_integers(%arg0: i8) {
+  // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+  %0 = quant.scast %arg0 : i8 to i8
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_both_quantized(%arg0: !qalias) {
+  // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+  %0 = quant.scast %arg0 : !qalias to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_scalar_tensor(%arg0: tensor<i8>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.scast %arg0 : tensor<i8> to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_ranked_unranked_tensor(%arg0: tensor<i8>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.scast %arg0 : tensor<i8> to tensor<*x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xi8>) {
+  // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+  %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<?x3x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_integer_type_mismatch(%arg0: i32) {
+  // expected-error@+1 {{storage type in quantized type expected to match integer type}}
+  %0 = quant.scast %arg0 : i32 to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_scalar(%arg0: i8) {
+  // expected-error@+1 {{scalar types may not use per-axis quantization}}
+  %0 = quant.scast %arg0 : i8 to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3xi8>) {
+  // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+  %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<2x3x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_invalid_size(%arg0: tensor<2x3x4xi8>) {
+  // expected-error@+1 {{quantized dimension size does not match number of scales}}
+  %0 = quant.scast %arg0 : tensor<2x3x4xi8> to tensor<2x3x4x!qalias>
+  return
+}
 
diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
index ab3d6decfb248..4abc5830d081e 100644
--- a/mlir/test/Dialect/Quant/ops.mlir
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -94,3 +94,58 @@ func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) {
   return
 }
 
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_scalar(%arg0: i8) {
+  %0 = quant.scast %arg0 : i8 to !qalias
+  %1 = quant.scast %0 : !qalias to i8
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_ranked(%arg0: tensor<2x?x4xi8>) {
+  %0 = quant.scast %arg0 : tensor<2x?x4xi8> to tensor<2x?x4x!qalias>
+  %1 = quant.scast %0 : tensor<2x?x4x!qalias> to tensor<2x?x4xi8>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_unranked(%arg0: tensor<*xi8>) {
+  %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias>
+  %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_static(%arg0: tensor<1x2x3xi8>) {
+  %0 = quant.scast %arg0 : tensor<1x2x3xi8> to tensor<1x2x3x!qalias>
+  %1 = quant.scast %0 : tensor<1x2x3x!qalias> to tensor<1x2x3xi8>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_dynamic(%arg0: tensor<?x?x?xi8>) {
+  %0 = quant.scast %arg0 : tensor<?x?x?xi8> to tensor<?x?x?x!qalias>
+  %1 = quant.scast %0 : tensor<?x?x?x!qalias> to tensor<?x?x?xi8>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) {
+  %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias>
+  %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8>
+  return
+}
+
+

>From 764b1d5afaca8ac397d85157c54e2717fcefd0ac Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 24 Jul 2024 08:08:40 -0400
Subject: [PATCH 16/18] Canonicalization patterns for all ops and unit tests

---
 .../include/mlir/Dialect/Quant/IR/QuantOps.td |   2 +
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        |  67 ++++++---
 mlir/test/Dialect/Quant/canonicalize.mlir     | 134 +++++++++++++++---
 3 files changed, 164 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 5dab02e8e1ee5..036940119b349 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -48,6 +48,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [
   let results = (outs quant_FloatScalarOrTensor:$result);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
   let hasVerifier = 1;
+  let hasFolder = 1;
   let extraClassDeclaration = [{
     /// Return the float type of the scalar or tensor result.
     FloatType getFloatType();
@@ -87,6 +88,7 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [
   let results = (outs quant_QuantizedScalarOrTensor:$result);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
   let hasVerifier = 1;
+  let hasFolder = 1;
   let extraClassDeclaration = [{
     /// Return the float type of the scalar or tensor input.
     FloatType getFloatType();
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index f722eb8e30806..6a709488ce01c 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -117,6 +117,17 @@ LogicalResult DequantizeCastOp::verify() {
                               getInput().getType());
 }
 
+OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
+  // Folds x -> quant.qcast -> quant.dcast -> y into x. Values x and y are
+  // guaranteed to have the same type in this pattern (re-checked below); any
+  // rounding or clamping performed by the quantize step is discarded.
+  if (auto producer = getInput().getDefiningOp<QuantizeCastOp>()) {
+    assert(producer.getInput().getType() == getType());
+    return producer.getInput();
+  }
+  return {};
+}
+
 FloatType DequantizeCastOp::getFloatType() {
   return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
 }
@@ -135,6 +146,18 @@ LogicalResult QuantizeCastOp::verify() {
                               getInput().getType());
 }
 
+OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
+  // Folds x -> quant.dcast -> quant.qcast -> y into x when the two casts are
+  // exact inverses, i.e. when x already has the type of y. Unlike the
+  // quant.dcast folder, equal types are not guaranteed here: the qcast may use
+  // different quantization parameters, in which case nothing is folded.
+  auto producer = getInput().getDefiningOp<DequantizeCastOp>();
+  if (!producer)
+    return {};
+  Value source = producer.getInput();
+  return source.getType() == getType() ? OpFoldResult(source) : OpFoldResult();
+}
+
 FloatType QuantizeCastOp::getFloatType() {
   return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
 }
@@ -148,6 +171,28 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
 // StorageCastOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult StorageCastOp::verify() {
+  // The integer side must be exactly the quantized type's storage type.
+  QuantizedType quantizedType = getQuantizedType();
+  if (quantizedType.getStorageType() != getIntegerType())
+    return emitError(
+        "storage type in quantized type expected to match integer type");
+
+  // Verify per-axis quantization information, if present, against the input
+  // type. The quantized type may sit on either side of the cast, but input and
+  // result shapes are already known to be identical at this point.
+  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
+}
+
+OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
+  // Folds x -> quant.scast -> quant.scast -> y into x. The two casts are only
+  // inverses of each other when x already has the type of y.
+  auto producer = getInput().getDefiningOp<StorageCastOp>();
+  bool isInversePair = producer && producer.getInput().getType() == getType();
+
+  return isInversePair ? OpFoldResult(producer.getInput()) : OpFoldResult();
+}
+
 IntegerType StorageCastOp::getIntegerType() {
   auto inputScalarType = getElementTypeOrSelf(getInput().getType());
   if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
@@ -166,28 +211,6 @@ QuantizedType StorageCastOp::getQuantizedType() {
   return cast<QuantizedType>(resultScalarType);
 }
 
-OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
-  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
-  // value of x if the casts invert each other.
-  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
-  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
-    return OpFoldResult();
-  return srcScastOp.getInput();
-}
-
-LogicalResult StorageCastOp::verify() {
-  auto quantizedType = getQuantizedType();
-  auto integerType = getIntegerType();
-  if (quantizedType.getStorageType() != integerType)
-    return emitError(
-        "storage type in quantized type expected to match integer type");
-
-  // Verify integrity of per-axis quantization information, if available. While
-  // the quantization type may appear in the input or the result, their tensor
-  // shapes are guaranteed to be identical at this point.
-  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
-}
-
 
 } // namespace quant
 } // namespace mlir
diff --git a/mlir/test/Dialect/Quant/canonicalize.mlir b/mlir/test/Dialect/Quant/canonicalize.mlir
index 36c3eaf5e10d2..73c57e2a48212 100644
--- a/mlir/test/Dialect/Quant/canonicalize.mlir
+++ b/mlir/test/Dialect/Quant/canonicalize.mlir
@@ -1,24 +1,124 @@
 // RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
 
+// CHECK-LABEL: @dcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+  %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 // -----
-// CHECK-LABEL: redundant_scast
-func.func @redundant_scast() -> tensor<4xi8> {
-  // CHECK-NEXT: arith.constant dense<10> : tensor<4xi8>
-  // CHECK-NEXT: return
-  %cst = arith.constant dense<5> : tensor<4xi8>
-  %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
-  %2 = "quant.scast"(%1) : (tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<4xi8>
-  %3 = arith.addi %2, %2 : tensor<4xi8>
-  return %3 : tensor<4xi8>
+
+// CHECK-LABEL: @dcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.dcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_no_fold_source(%arg0: tensor<4xi8>) -> tensor<4xf32> {
+  %0 = quant.scast %arg0 : tensor<4xi8> to tensor<4x!qalias>
+  %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+  return %1 : tensor<4xf32>
 }
 
 // -----
-// CHECK-LABEL: non_redundant_scast
-func.func @non_redundant_scast() -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>> {
-  // CHECK-NEXT: arith.constant dense<5> : tensor<4xi8>
-  // CHECK-NEXT: scast
-  // CHECK-NEXT: return
-  %cst = arith.constant dense<5> : tensor<4xi8>
-  %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
-  return %1 : tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
+
+// CHECK-LABEL: @qcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+  %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+  %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+  return %1 : tensor<4x!qalias>
 }
+
+// -----
+
+// CHECK-LABEL: @qcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = arith.negf %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4x!qalias> {
+  %0 = arith.negf %arg0 : tensor<4xf32>
+  %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+  return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.dcast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @qcast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+  %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+  %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias1>
+  return %1 : tensor<4x!qalias1>
+}
+
+// -----
+
+// CHECK-LABEL: @scast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+  %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+  %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias>
+  return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @scast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[QCAST:.*]] = quant.qcast %[[ARG_0]]
+// CHECK: %[[SCAST:.*]] = quant.scast %[[QCAST]]
+// CHECK: return %[[SCAST]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4xi8> {
+  %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+  %1 = quant.scast %0 : tensor<4x!qalias> to tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @scast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.scast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @scast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+  %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+  %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias1>
+  return %1 : tensor<4x!qalias1>
+}
+

>From 70c5af961b01b91b7d1664ebad7fa528a37b35af Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 25 Jul 2024 00:23:28 -0400
Subject: [PATCH 17/18] Pass 'strip-func-quant-types'

---
 .../mlir/Dialect/Quant/Transforms/Passes.td   |  15 +++
 .../Dialect/Quant/Transforms/CMakeLists.txt   |   1 +
 .../Quant/Transforms/StripFuncQuantTypes.cpp  | 114 ++++++++++++++++++
 .../Dialect/Quant/strip-func-quant-types.mlir |  88 ++++++++++++++
 4 files changed, 218 insertions(+)
 create mode 100644 mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
 create mode 100644 mlir/test/Dialect/Quant/strip-func-quant-types.mlir

diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index 56e10688b0c98..b25296d4db5a9 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -31,4 +31,19 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
   ];
 }
 
+def StripFuncQuantTypes : Pass<"strip-func-quant-types"> {
+  let summary = "Strip quantized types from function headers";
+  let description = [{
+    Replace quantized types, including tensor element types, in function
+    arguments and results with their storage (signless integer) types, and
+    update `func.return` and `func.call` ops to match. Each converted argument
+    gets a `quant.scast` at the head of the entry block restoring the original
+    quantized value; `quant.scast` ops are also inserted at returns and calls.
+  }];
+  let dependentDialects = [
+    "func::FuncDialect",
+    "quant::QuantDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 2daea7750cfe3..662c3e368b624 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRQuantTransforms
   LowerQuantOps.cpp
+  StripFuncQuantTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
new file mode 100644
index 0000000000000..8996eff61a39c
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -0,0 +1,114 @@
+//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Strips quantized types from function headers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Maps quantized types (and tensors of them) to their storage integer types.
+class QuantizedTypeConverter : public TypeConverter {
+  static Type convertQuantizedType(QuantizedType quantizedType) {
+    return quantizedType.getStorageType();
+  }
+
+  static Type convertTensorType(TensorType tensorType) {
+    auto elementType = dyn_cast<QuantizedType>(tensorType.getElementType());
+    return elementType ? tensorType.clone(convertQuantizedType(elementType))
+                       : tensorType;
+  }
+
+  static Value materializeConversion(OpBuilder &builder, Type type,
+                                     ValueRange inputs, Location loc) {
+    assert(inputs.size() == 1);
+    return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+  }
+
+public:
+  explicit QuantizedTypeConverter() {
+    // Tried in reverse registration order: the identity fallback goes last.
+    addConversion([](Type type) { return type; });
+    addConversion(convertQuantizedType);
+    addConversion(convertTensorType);
+
+    addArgumentMaterialization(materializeConversion);
+    addSourceMaterialization(materializeConversion);
+    addTargetMaterialization(materializeConversion);
+  }
+};
+
+// Pass that strips quantized types from the headers of `func.func` ops and
+// from the `func.return` and `func.call` ops that must stay consistent with
+// them. Values of the original types are rebuilt with `quant.scast` ops
+// created by the type converter's materializations.
+class StripFuncQuantTypes
+    : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+public:
+
+  /// Converts the func ops nested under the anchor op, signalling pass failure
+  /// if any of them cannot be legalized.
+  void runOnOperation() override {
+    // The pass declares no anchor op, so the root is not necessarily a
+    // `builtin.module`; `applyPartialConversion` works on any op.
+    Operation *rootOp = getOperation();
+    MLIRContext *context = &getContext();
+
+    QuantizedTypeConverter typeConverter;
+    ConversionTarget target(*context);
+    RewritePatternSet patterns(context);
+
+    // Mark func.func, func.return, and func.call illegal if they contain any
+    // quantized types.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>(
+        [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
+    target.addDynamicallyLegalOp<func::CallOp>(
+        [&](func::CallOp op) { return typeConverter.isLegal(op); });
+
+    // Register the patterns that rewrite function signatures, returns, and
+    // calls with the converted types.
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+
+    // Apply the conversion. This is a partial conversion, so ops the target
+    // says nothing about are allowed to remain as they are.
+    if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir
+
diff --git a/mlir/test/Dialect/Quant/strip-func-quant-types.mlir b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
new file mode 100644
index 0000000000000..e5f0d4921bed3
--- /dev/null
+++ b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @strip_operands
+// CHECK-SAME: %[[ARG_0:.*]]: i8
+// CHECK-SAME: %[[ARG_1:.*]]: i16
+// CHECK-SAME: %[[ARG_2:.*]]: f32
+
+// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}>
+// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}>
+
+// CHECK: "test.custom_op"(%[[ARG_0_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_1_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_2]])
+
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) {
+  "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias>
+  "test.custom_op"(%arg1) : (!qalias1) -> tensor<?x!qalias1>
+  "test.custom_op"(%arg2) : (f32) -> tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @strip_results
+// CHECK-SAME: tensor<4xi8>, tensor<?xi16>, tensor<*xi8>, tensor<4xf32>
+
+// CHECK: %[[RESULT_0:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[RESULT_1:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULT_2:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8>
+
+// CHECK: %[[RESULT_3:.*]] = "test.custom_op"()
+
+// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]]
+
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_results() -> (tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>) {
+  %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+  %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+  %2 = "test.custom_op"() : () -> tensor<*x!qalias>
+  %3 = "test.custom_op"() : () -> tensor<4xf32>
+  return %0, %1, %2, %3 : tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @callee
+// CHECK-SAME: (tensor<4xi8>, tensor<?xi16>) -> (tensor<*xi8>, tensor<4xf32>)
+
+// CHECK-LABEL: @strip_call
+
+// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]])
+
+// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>>
+// CHECK: "test.custom_op"(%[[RESULT_0_CAST]])
+
+// CHECK: "test.custom_op"(%[[RESULTS]]#1)
+
+// CHECK: return
+
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func private @callee(tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+
+func.func @strip_call() {
+  %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+  %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+  %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+  "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> ()
+  "test.custom_op"(%2#1) : (tensor<4xf32>) -> ()
+  return
+}

>From cce8171c6d016d823e514ec304f94d2e8c4085c0 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 25 Jul 2024 18:43:30 -0400
Subject: [PATCH 18/18] Dialect documentation progress

---
 .../mlir/Dialect/Quant/IR/QuantBase.td        |  53 +++++-
 .../include/mlir/Dialect/Quant/IR/QuantOps.td | 178 +++++++++++++-----
 2 files changed, 186 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index d81838db3dc1a..2690b4fe0b111 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -17,8 +17,59 @@ include "mlir/IR/OpBase.td"
 
 def Quant_Dialect : Dialect {
   let name = "quant";
+  let description = [{
+    ## Per-axis quantization integrity
+
+    When type `!quant.uniform` contains per-axis quantization information, the
+    rules below apply. These rules guarantee that the quantization
+    information encoded in the data type is applicable to the context in which
+    the quantized type is used. For efficiency, these rules are actively
+    enforced only by the verifiers of `quant` dialect ops, but they must be
+    respected in any context in which the `!quant.uniform` data type is used,
+    such as the header of a `func.func` op, or the input of an arithmetic
+    operation.
+
+    - A quantized type with per-channel quantization information must be the
+      element type of a tensor container type, and may not occur directly as
+      the data type of a scalar value.
+
+    ```
+    // Incorrect. Type `!quant.uniform` specifies per-channel quantization for a
+    // scalar type.
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0, 2.0}>
+
+    // Correct. Type `!quant.uniform` with per-channel quantization is wrapped in
+    // a `tensor` type.
+    %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
+    ```
+
+    - If the tensor containing the `!quant.uniform` type is ranked, its rank
+      must be greater than the channel axis specified in the quantized type.
+
+    ```
+    // Incorrect. The tensor rank (2) is not greater than the channel axis in the
+    // quantized type (3).
+    %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>
+
+    // Correct. The tensor rank (2) is now greater than the channel axis (1):
+    %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:1, {1.0, 2.0}>>
+    ```
+
+    - If the axis dimension in the containing tensor is static, its size must
+      be equal to the number of scales present in the quantized type.
+
+    ```
+    // Incorrect. The channel axis is 1, and the size of dimension 1 in the
+    // containing tensor is 3. However, there are 4 scale values present in the
+    // quantized type.
+    %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>
+
+    // Correct. The quantized type now includes 3 scale values, matching the size
+    // of dimension 1 of the result tensor.
+    %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+    ```
+  }];
   let cppNamespace = "::mlir::quant";
-
   let useDefaultTypePrinterParser = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 036940119b349..52dfc6b051de7 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -31,18 +31,55 @@ class quant_Op<string mnemonic, list<Trait> traits> :
 def quant_DequantizeCastOp : quant_Op<"dcast", [
     Pure,
     quant_SameScalarOrTensorShape]> {
-  let summary = "convert back from a quantized to quantizable (expressed) type operation";
+  let summary = "Dequantize cast operation";
   let description = [{
-    A DequantizeCast op `dcast` represents the inverse of a `qcast`,
-    converting back from a quantized to quantizable (expressed) type.
+    Convert an input quantized value into its expressed floating-point value.
+    The dequantization process consists of the following steps:
 
-    Like `qcast`s, a `dcast` is allowed to have both its operand and result
-    as non quantized types. This facilitates transformations and marks edges
-    where the computation must be carried out in the expressed type.
+    ```
+    def dequantize(quantizedValue: quantizedType) -> expressedType:
+        storedValue = reinterpretCast(quantizedValue, storageType)
+        storedValueFloat = convertIntToFloat(storedValue, expressedType)
+        zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+        expressedValue = (storedValueFloat - zeroPointFloat) * scale
+        return expressedValue
+    ```
+
+    Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained
+    from the corresponding parameters encoded in `quantizedType`. For
+    per-channel quantization, the appropriate `scale` and `zeroPoint` values
+    are used for each tensor element computation according to the channel the
+    element belongs to.
+
+    The operation must satisfy the following syntactic constraints:
+
+    - Operand `input` must be a scalar or tensor of type `!quant.uniform`.
+
+    - The result type must be a floating-point scalar or tensor.
+
+    - The `expressedType` parameter of the `!quant.uniform` type of the input
+      must match the floating-point type of the result.
+
+    - The operand and result types must be both scalars or both tensors. If
+      tensors, they must be both ranked or both unranked. If ranked, both must
+      have the same shape, including matching static and dynamic dimensions.
+
+    - If the operand uses per-channel quantization, its `!quant.uniform` type
+      must adhere to the [Per-axis quantization
+      integrity](#per-axis-quantization-integrity) guidelines.
+
+    Examples:
+
+    ```
+    // Dequantize a scalar quantized value
+    %result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
+
+    // Dequantize a dynamically shaped tensor of quantized values
+    %result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
 
-    Especially early in transformation, it is common to have `dcast`s on
-    all operands to ops that must operate with the expressed type (typically
-    math ops prior to lowering to target-specific, quantized kernels).
+    // Dequantize an unranked tensor using per-axis quantization information
+    %result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>
+    ```
   }];
   let arguments = (ins quant_QuantizedScalarOrTensor:$input);
   let results = (outs quant_FloatScalarOrTensor:$result);
@@ -61,28 +98,57 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [
 def quant_QuantizeCastOp : quant_Op<"qcast", [
     Pure,
     quant_SameScalarOrTensorShape]> {
-  let summary = "convert a quantizable type to a quantized type";
+  let summary = "Quantize cast operation";
   let description = [{
-    A QuantizeCast `qcast` represents a potential type shift from a quantizable
-    type to a quantized type.
-
-    At runtime, a `qcast` will apply the transformation expressed by its
-    operand and result type. For flexibility during transformation, it is also
-    possible to have a `qcast` that performs no transformation (both its
-    operand and result type are quantizable).
-
-    A `qcast` will typically originate from either:
-      a) An expressed or implied constraint in the source dialect which signals
-         that a certain level of quantization is possible or required.
-      b) An inference made by a quantization algorithm indicating that a
-         quantized representation may be acceptable.
-
-    Especially early in transformation, it is common to have pairs of
-    `qcast` and `dcast` at points where a transition to a quantized type is
-    required. In addition, it is also common to have an identity `qcast`
-    (where the operand and result type are not quantized) at all points where
-    it is legal to use a quantized representation (but is not known to be
-    acceptable).
+    Convert a floating-point value to a quantized type. The quantization
+    process consists of the following steps:
+
+    ```
+    def quantize(expressedValue: expressedType) -> quantizedType:
+        zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+        scaledValue = expressedValue / scale
+        storedValueFloat = scaledValue + zeroPointFloat
+        storedValue = convertFloatToInt(storedValueFloat, storageType)
+        storedValueClamped = clamp(storedValue, storageMin, storageMax)
+        quantizedValue = reinterpretCast(storedValueClamped, quantizedType)
+        return quantizedValue
+    ```
+
+    Here, `storageType`, `storageMin`, `storageMax`, `expressedType`, `scale`,
+    and `zeroPoint` are obtained from the corresponding parameters encoded in
+    `quantizedType`. For per-channel quantization, the appropriate `scale` and
+    `zeroPoint` values are used for each tensor element computation according
+    to the channel the element belongs to.
+
+    The operation must satisfy the following syntactic constraints:
+
+    - Operand `input` must be a floating-point scalar or tensor.
+
+    - The result type must be a scalar or tensor of type `!quant.uniform`.
+
+    - The `expressedType` parameter in the `!quant.uniform` type of the result
+      must match the floating-point type of the input.
+
+    - The operand and result types must be both scalars or both tensors. If
+      tensors, they must be both ranked or both unranked. If ranked, both must
+      have the same shape, including matching static and dynamic dimensions.
+
+    - If the result uses per-channel quantization, its `!quant.uniform` type
+      must adhere to the [Per-axis quantization
+      integrity](#per-axis-quantization-integrity) guidelines.
+
+    Examples:
+
+    ```
+    // Quantize a scalar floating-point value
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32, 2.0>
+
+    // Quantize a dynamically shaped tensor of floating-point values
+    %result = quant.qcast %input : tensor<?xf32> to tensor<?x!quant.uniform<i8:f32, 2.0>>
+
+    // Quantize an unranked tensor using per-axis quantization information
+    %result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
+    ```
   }];
   let arguments = (ins quant_FloatScalarOrTensor:$input);
   let results = (outs quant_QuantizedScalarOrTensor:$result);
@@ -102,24 +168,48 @@ def quant_StorageCastOp : quant_Op<"scast", [
     Pure,
     quant_SameScalarOrTensorShape,
     quant_IntegerAndQuantizedCombination]> {
-  let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
+  let summary = "Storage cast operation";
   let description = [{
-    A StorageCast `scast` represents a cast from or to a type based on the
-    storage type and a type based on a corresponding quantized type.
+    Convert a value from a quantized type to the corresponding signless integer
+    storage type, or vice versa. This conversion merely reinterprets the
+    input bits as the target type and does not perform any data
+    manipulation.
 
-    This op exists to ensure type coherency for between parts of the computation
-    which are operating directly on an underlying storage type and those which
-    operate on quantized values.
+    The operation must satisfy the following syntactic constraints:
+
+    - Operand `input` must be a scalar or tensor of a signless integer or
+      `!quant.uniform` type.
+
+    - The result must be a scalar or tensor of a signless integer or
+      `!quant.uniform` type.
+
+    - If the operand is a scalar or tensor of signless integer type, the
+      result must be a scalar or tensor of type `!quant.uniform`, and vice versa.
+
+    - The operand and result must be both scalars or both tensors. If tensors,
+      they must be both ranked or both unranked. If ranked, both must have the
+      same shape, including matching static and dynamic dimensions.
+
+    - The bit width of the `storageType` parameter of the `!quant.uniform`
+      type must match the bit width of the signless integer type on the other
+      side of the cast.
+
+    - If the operand or result uses per-channel quantization, its
+      `!quant.uniform` type must adhere to the [Per-axis quantization
+      integrity](#per-axis-quantization-integrity) guidelines.
+
+    Examples:
 
-    Examples from storage to quantized type:
-    ```
-    i8 -> !quant<"uniform[i8:f32]{1.0}">
-    ```
-    ```
-    tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
-    ```
     ```
-    vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+    // Cast a scalar quantized value into its storage type
+    %result = quant.scast %input : !quant.uniform<i8:f32, 2.0> to i8
+
+    // Cast a dynamically shaped tensor of quantized values into their storage type
+    %result = quant.scast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xi8>
+
+    // Cast an unranked tensor of signless integers into a quantized type using
+    // per-axis quantization information
+    %result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
     ```
   }];
   let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);



More information about the Mlir-commits mailing list