[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)
Rafael Ubal
llvmlistbot at llvm.org
Thu Jul 25 15:47:19 PDT 2024
https://github.com/rafaelubalmw created https://github.com/llvm/llvm-project/pull/100667
Full revamp of the 'quant' dialect. This patch series implements the RFC at https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942
>From 4de6f81e366e04b51b1ad1a22e911b5412302ec6 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 8 Jul 2024 17:58:38 -0400
Subject: [PATCH 01/18] File system restructure for 'quant' dialect
---
.../include/mlir/Dialect/Quant/CMakeLists.txt | 8 ++----
.../mlir/Dialect/Quant/IR/CMakeLists.txt | 6 +++++
.../Dialect/Quant/{QuantOps.h => IR/Quant.h} | 12 ++++-----
.../{QuantOpsBase.td => IR/QuantBase.td} | 8 +++---
.../Quant/{ => IR}/QuantDialectBytecode.td | 0
.../mlir/Dialect/Quant/{ => IR}/QuantOps.td | 8 +++---
.../mlir/Dialect/Quant/{ => IR}/QuantTypes.h | 6 ++---
.../Dialect/Quant/Transforms/CMakeLists.txt | 5 ++++
.../mlir/Dialect/Quant/Transforms/Passes.h | 27 +++++++++++++++++++
.../mlir/Dialect/Quant/Transforms/Passes.td | 26 ++++++++++++++++++
.../Quant/{ => Utils}/FakeQuantSupport.h | 8 +++---
.../Quant/{ => Utils}/UniformSupport.h | 8 +++---
.../mlir/Dialect/Tosa/Utils/QuantUtils.h | 4 +--
mlir/include/mlir/InitAllDialects.h | 2 +-
mlir/lib/CAPI/Dialect/Quant.cpp | 4 +--
.../Dialect/Quant/IR/QuantDialectBytecode.cpp | 6 ++---
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 10 +++----
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 4 +--
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 4 +--
.../Dialect/Quant/Utils/FakeQuantSupport.cpp | 4 +--
.../Dialect/Quant/Utils/UniformSupport.cpp | 2 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 2 +-
23 files changed, 113 insertions(+), 53 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
rename mlir/include/mlir/Dialect/Quant/{QuantOps.h => IR/Quant.h} (69%)
rename mlir/include/mlir/Dialect/Quant/{QuantOpsBase.td => IR/QuantBase.td} (93%)
rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantDialectBytecode.td (100%)
rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantOps.td (96%)
rename mlir/include/mlir/Dialect/Quant/{ => IR}/QuantTypes.h (99%)
create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
create mode 100644 mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
rename mlir/include/mlir/Dialect/Quant/{ => Utils}/FakeQuantSupport.h (93%)
rename mlir/include/mlir/Dialect/Quant/{ => Utils}/UniformSupport.h (97%)
diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
index c08f399ee182d..9f57627c321fb 100644
--- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
@@ -1,6 +1,2 @@
-add_mlir_dialect(QuantOps quant)
-add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
-
-set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
-mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
-add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..c08f399ee182d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(QuantOps quant)
+add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
+mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
+add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
similarity index 69%
rename from mlir/include/mlir/Dialect/Quant/QuantOps.h
rename to mlir/include/mlir/Dialect/Quant/IR/Quant.h
index 14fb3035ab0d3..a703612d6b489 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -1,4 +1,4 @@
-//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
+//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_
-#define MLIR_DIALECT_QUANT_QUANTOPS_H_
+#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_
+#define MLIR_DIALECT_QUANT_IR_QUANT_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -19,9 +19,9 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
-#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc"
#define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.h.inc"
-#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_
+#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
similarity index 93%
rename from mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index da822d0a61deb..dadca06091b1e 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -1,4 +1,4 @@
-//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
+//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef DIALECT_QUANT_QUANT_OPS_BASE_
-#define DIALECT_QUANT_QUANT_OPS_BASE_
+#ifndef QUANT_BASE
+#define QUANT_BASE
include "mlir/IR/OpBase.td"
@@ -71,4 +71,4 @@ def quant_UniformQuantizedType :
def quant_UniformQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
-#endif // DIALECT_QUANT_QUANT_OPS_BASE_
+#endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
similarity index 100%
rename from mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
similarity index 96%
rename from mlir/include/mlir/Dialect/Quant/QuantOps.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 7937265ce2f20..a3a0ff1608a66 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -10,10 +10,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef DIALECT_QUANT_QUANT_OPS_
-#define DIALECT_QUANT_QUANT_OPS_
+#ifndef QUANT_OPS
+#define QUANT_OPS
-include "mlir/Dialect/Quant/QuantOpsBase.td"
+include "mlir/Dialect/Quant/IR/QuantBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -100,4 +100,4 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
let hasFolder = 1;
}
-#endif // DIALECT_QUANT_QUANT_OPS_
+#endif // QUANT_OPS
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
similarity index 99%
rename from mlir/include/mlir/Dialect/Quant/QuantTypes.h
rename to mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index de5aed0a91a20..c020e1b46ad4e 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
-#define MLIR_DIALECT_QUANT_QUANTTYPES_H
+#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
+#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -412,4 +412,4 @@ class CalibratedQuantizedType
} // namespace quant
} // namespace mlir
-#endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
+#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..30f7c1696bdb9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant)
+add_public_tablegen_target(MLIRQuantTransformsIncGen)
+
+add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
new file mode 100644
index 0000000000000..0b7378651afa1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
@@ -0,0 +1,27 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
new file mode 100644
index 0000000000000..f511c90ec6931
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -0,0 +1,26 @@
+//===-- Passes.td - Quant pass definition file -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> {
+ let summary = "Lower 'quant.dcast' and 'quant.qcast' ops.";
+ let description = [{
+ Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
+ into other core dialects.
+
+ The lowering process generates storage type casts in the form of
+ `quant.scast` ops to convert operands and results from quantized types to
+ the corresponding storage type, or vice versa.
+ let dependentDialects = ["quant::QuantDialect"];
+}
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
similarity index 93%
rename from mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
index 367d468b2acf1..6551efc6242a6 100644
--- a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
@@ -34,10 +34,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
-#define MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
namespace mlir {
namespace quant {
@@ -64,4 +64,4 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
} // namespace quant
} // namespace mlir
-#endif // MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
similarity index 97%
rename from mlir/include/mlir/Dialect/Quant/UniformSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
index 4119aced4c075..6773f45069c87 100644
--- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
-#define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
#include <utility>
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
@@ -218,4 +218,4 @@ class UniformQuantizedPerAxisValueConverter {
} // namespace quant
} // namespace mlir
-#endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 298c97015fe2e..5e80745777b3b 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -16,8 +16,8 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
namespace mlir {
namespace tosa {
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 549c26c72d8a1..75e62cda90d45 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -64,7 +64,7 @@
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 0a7181d8bc17c..b30d1de73288c 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -8,8 +8,8 @@
#include "mlir-c/Dialect/Quant.h"
#include "mlir/CAPI/Registration.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index c0c00fb4893cb..0f4b755367495 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -9,8 +9,8 @@
#include "QuantDialectBytecode.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVector.h"
@@ -31,7 +31,7 @@ static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader,
return success();
}
-#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc"
/// This class implements the bytecode interface for the Quant dialect.
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index c9a6bbc9ceeea..fa9725c23d643 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
#include "QuantDialectBytecode.h"
#include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
@@ -24,14 +24,14 @@ using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;
-#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
void QuantizationDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
>();
addBytecodeInterface(this);
}
@@ -46,4 +46,4 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
}
#define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be..15cde77e40afb 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantTypes.h"
#include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 926a8a0aa13d5..c882a616f397c 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index 8c69729824691..fb27640bfd278 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
using namespace mlir;
using namespace mlir::quant;
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index 408701f80444a..62c7a7128d63a 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
#include "mlir/IR/BuiltinTypes.h"
#include <numeric>
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8687be075ea67..5d082ea9b1010 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -11,7 +11,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4337787e4aead..d3aff36a763b6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -14,7 +14,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
>From d6e9adcc3d8859f9dbedbf9360d59cc91a99b9ed Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 8 Jul 2024 19:26:06 -0400
Subject: [PATCH 02/18] Successfully compiled empty pass '-lower-quant-ops'
---
.../mlir/Dialect/Quant/IR/QuantBase.td | 4 +--
.../include/mlir/Dialect/Quant/IR/QuantOps.td | 2 +-
.../mlir/Dialect/Quant/Transforms/Passes.td | 5 +--
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 2 +-
mlir/include/mlir/InitAllDialects.h | 2 +-
mlir/include/mlir/InitAllPasses.h | 2 ++
mlir/lib/CAPI/Dialect/Quant.cpp | 2 +-
mlir/lib/Dialect/Quant/CMakeLists.txt | 1 +
.../Dialect/Quant/IR/QuantDialectBytecode.cpp | 2 +-
.../Dialect/Quant/IR/QuantDialectBytecode.h | 4 +--
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 2 +-
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 2 +-
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 4 +--
.../Dialect/Quant/Transforms/CMakeLists.txt | 15 ++++++++
.../Quant/Transforms/LowerQuantOps.cpp | 34 +++++++++++++++++++
15 files changed, 68 insertions(+), 15 deletions(-)
create mode 100644 mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index dadca06091b1e..e465d855c1986 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -15,7 +15,7 @@
include "mlir/IR/OpBase.td"
-def Quantization_Dialect : Dialect {
+def Quant_Dialect : Dialect {
let name = "quant";
let cppNamespace = "::mlir::quant";
@@ -63,7 +63,7 @@ def quant_RealOrStorageValueType :
// An implementation of UniformQuantizedType.
def quant_UniformQuantizedType :
- DialectType<Quantization_Dialect,
+ DialectType<Quant_Dialect,
CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
"UniformQuantizedType">;
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index a3a0ff1608a66..ba282d50328da 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
class quant_Op<string mnemonic, list<Trait> traits> :
- Op<Quantization_Dialect, mnemonic, traits>;
+ Op<Quant_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Quantization casts
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index f511c90ec6931..e43062a98b1ea 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -11,7 +11,7 @@
include "mlir/Pass/PassBase.td"
-def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> {
+def LowerQuantOps : Pass<"lower-quant-ops"> {
let summary = "Lower 'quant.dcast' and 'quant.qcast' ops.";
let description = [{
Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
@@ -20,7 +20,8 @@ def QuantLowerQuantOpsPass : Pass<"lower-quant-ops"> {
The lowering process generates storage type casts in the form of
`quant.scast` ops to convert operands and results from quantized types to
the corresponding storage type, or vice versa.
- let dependentDialects = ["quant::QuantDialect"];
+ }];
+ let dependentDialects = ["::mlir::quant::QuantDialect"];
}
#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d2..df91ba51a0594 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -40,7 +40,7 @@ def Tosa_Dialect : Dialect {
there will be tools to lower from the ML frameworks into TOSA.
}];
- let dependentDialects = ["tensor::TensorDialect", "quant::QuantizationDialect"];
+ let dependentDialects = ["tensor::TensorDialect", "quant::QuantDialect"];
let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 75e62cda90d45..08de36fe21db0 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -136,7 +136,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
pdl_interp::PDLInterpDialect,
polynomial::PolynomialDialect,
ptr::PtrDialect,
- quant::QuantizationDialect,
+ quant::QuantDialect,
ROCDL::ROCDLDialect,
scf::SCFDialect,
shape::ShapeDialect,
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 1b9c1b193ace6..dd8b292a87344 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -35,6 +35,7 @@
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
memref::registerMemRefPasses();
mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
+ quant::registerQuantPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index b30d1de73288c..c94dbb5692fdb 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -13,7 +13,7 @@
using namespace mlir;
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
//===---------------------------------------------------------------------===//
// QuantizedType
diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt
index 037bba8dcb5c9..31167e6af908b 100644
--- a/mlir/lib/Dialect/Quant/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index 0f4b755367495..6a4ac310eb052 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -64,6 +64,6 @@ struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
};
} // namespace
-void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
+void quant::detail::addBytecodeInterface(QuantDialect *dialect) {
dialect->addInterfaces<QuantDialectBytecodeInterface>();
}
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
index 9e9cbf66d84d9..eef2b5bbefecc 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
@@ -15,12 +15,12 @@
#define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
namespace mlir::quant {
-class QuantizationDialect;
+class QuantDialect;
namespace detail {
/// Add the interfaces necessary for encoding the quantization dialect
/// components in bytecode.
-void addBytecodeInterface(QuantizationDialect *dialect);
+void addBytecodeInterface(QuantDialect *dialect);
} // namespace detail
} // namespace mlir::quant
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index fa9725c23d643..49c05aa7f98d3 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -26,7 +26,7 @@ using namespace mlir::quant::detail;
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
-void QuantizationDialect::initialize() {
+void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addOperations<
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 15cde77e40afb..a4829d472ecad 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -25,7 +25,7 @@ unsigned QuantizedType::getFlags() const {
}
bool QuantizedType::classof(Type type) {
- return llvm::isa<QuantizationDialect>(type.getDialect());
+ return llvm::isa<QuantDialect>(type.getDialect());
}
LogicalResult
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index c882a616f397c..bf0f775146c1a 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -317,7 +317,7 @@ static Type parseCalibratedType(DialectAsmParser &parser) {
}
/// Parse a type registered to this dialect.
-Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
+Type QuantDialect::parseType(DialectAsmParser &parser) const {
// All types start with an identifier that we switch on.
StringRef typeNameSpelling;
if (failed(parser.parseKeyword(&typeNameSpelling)))
@@ -419,7 +419,7 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type,
}
/// Print a type registered to this dialect.
-void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
+void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
printAnyQuantizedType(anyType, os);
else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..2daea7750cfe3
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRQuantTransforms
+ LowerQuantOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
+
+ DEPENDS
+ MLIRQuantTransformsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRQuantDialect
+ MLIRPass
+ MLIRTransforms
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
new file mode 100644
index 0000000000000..72f89326f555b
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -0,0 +1,34 @@
+//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_LOWERQUANTOPS
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
+ void runOnOperation() override {
+ Operation *parentOp = getOperation();
+ }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir
>From 7e0a7b9edb810dc80eba435cbdabf2dd1e53d9b0 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 10 Jul 2024 19:51:20 -0400
Subject: [PATCH 03/18] Lowering for 'quant.qcast' with per-layer quantization
for ranked and unranked tensors
---
mlir/include/mlir/Dialect/Quant/IR/Quant.h | 8 +
.../include/mlir/Dialect/Quant/IR/QuantOps.td | 46 ++--
.../mlir/Dialect/Quant/IR/QuantTypes.h | 4 +
.../mlir/Dialect/Quant/Transforms/Passes.h | 2 +
.../mlir/Dialect/Quant/Transforms/Passes.td | 17 +-
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 53 +++-
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 11 +
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 7 +-
.../Quant/Transforms/LowerQuantOps.cpp | 253 +++++++++++++++++-
9 files changed, 363 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/Quant.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
index a703612d6b489..c5ca88ec69795 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/Quant.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -21,6 +21,14 @@
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc"
+namespace mlir {
+namespace quant {
+
+class UniformQuantizedType;
+
+} // namespace quant
+} // namespace mlir
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index ba282d50328da..48e2496203ff0 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -28,6 +28,24 @@ class quant_Op<string mnemonic, list<Trait> traits> :
// Quantization casts
//===----------------------------------------------------------------------===//
+def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
+ let summary = "convert back from a quantized to quantizable (expressed) type operation";
+ let description = [{
+ A DequantizeCast op `dcast` represents the inverse of a `qcast`,
+ converting back from a quantized to quantizable (expressed) type.
+
+ Like `qcast`s, a `dcast` is allowed to have both its operand and result
+ as non quantized types. This facilitates transformations and marks edges
+ where the computation must be carried out in the expressed type.
+
+ Especially early in transformation, it is common to have `dcast`s on
+ all operands to ops that must operate with the expressed type (typically
+ math ops prior to lowering to target-specific, quantized kernels).
+ }];
+ let arguments = (ins quant_RealValueType:$input);
+ let results = (outs quant_RealValueType:$result);
+}
+
def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
let summary = "convert a quantizable type to a quantized type";
let description = [{
@@ -52,26 +70,18 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
it is legal to use a quantized representation (but is not known to be
acceptable).
}];
- let arguments = (ins quant_RealValueType:$arg);
- let results = (outs quant_RealValueType:$res);
-}
+ let arguments = (ins quant_RealValueType:$input);
+ let results = (outs quant_RealValueType:$result);
-def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
- let summary = "convert back from a quantized to quantizable (expressed) type operation";
- let description = [{
- A DequantizeCast op `dcast` represents the inverse of a `qcast`,
- converting back from a quantized to quantizable (expressed) type.
+ let extraClassDeclaration = [{
- Like `qcast`s, a `dcast` is allowed to have both its operand and result
- as non quantized types. This facilitates transformations and marks edges
- where the computation must be carried out in the expressed type.
+ /// Return the primitive (scalar or tensor element) type of the float input.
+ FloatType getFloatType();
- Especially early in transformation, it is common to have `dcast`s on
- all operands to ops that must operate with the expressed type (typically
- math ops prior to lowering to target-specific, quantized kernels).
+ /// Return the primitive (scalar or tensor element) type of the quantized
+ /// result.
+ quant::UniformQuantizedType getQuantizedType();
}];
- let arguments = (ins quant_RealValueType:$arg);
- let results = (outs quant_RealValueType:$res);
}
def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
@@ -95,8 +105,8 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
```
}];
- let arguments = (ins quant_RealOrStorageValueType:$arg);
- let results = (outs quant_RealOrStorageValueType:$res);
+ let arguments = (ins quant_RealOrStorageValueType:$input);
+ let results = (outs quant_RealOrStorageValueType:$result);
let hasFolder = 1;
}
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index c020e1b46ad4e..d4c9e7f9286a6 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -114,6 +114,10 @@ class QuantizedType : public Type {
/// The maximum value that storageType can take.
int64_t getStorageTypeMax() const;
+ /// Return whether the storage type has explicit min or max boundaries
+ /// different from the minimum and maximum representable values.
+ bool hasStorageTypeBounds() const;
+
/// Gets the integral bit width that the underlying storage type can exactly
/// represent. For integral storage types, this will just be their width.
unsigned getStorageTypeIntegralWidth() const;
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
index 0b7378651afa1..84be2a21b34ed 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
@@ -21,6 +21,8 @@ namespace quant {
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns);
+
} // namespace quant
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index e43062a98b1ea..19aa37f653c01 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -11,17 +11,24 @@
include "mlir/Pass/PassBase.td"
-def LowerQuantOps : Pass<"lower-quant-ops"> {
- let summary = "Lower 'quant.dcast' and 'quant.qcast' ops.";
+def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
+ let summary = "Lower quant.dcast and quant.qcast ops";
let description = [{
Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
into other core dialects.
The lowering process generates storage type casts in the form of
- `quant.scast` ops to convert operands and results from quantized types to
- the corresponding storage type, or vice versa.
+ `quant.scast` ops to act as an interface between the original quantized
+ types of operands and results and their corresponding storage types used in
+ the generated arithmetic computations.
}];
- let dependentDialects = ["::mlir::quant::QuantDialect"];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "linalg::LinalgDialect",
+ "quant::QuantDialect",
+ "scf::SCFDialect",
+ "tensor::TensorDialect"
+ ];
}
#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 49c05aa7f98d3..2eafa348f6906 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -20,12 +20,27 @@
#include "llvm/Support/MathExtras.h"
#include <numeric>
-using namespace mlir;
-using namespace mlir::quant;
-using namespace mlir::quant::detail;
-
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
+
+namespace mlir {
+namespace quant {
+
+namespace {
+
+Type getPrimitiveType(Type ty) {
+ if (auto tensorType = dyn_cast<TensorType>(ty))
+ return tensorType.getElementType();
+ return ty;
+}
+
+} // namespace
+
+
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
+
void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
@@ -33,17 +48,39 @@ void QuantDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
>();
- addBytecodeInterface(this);
+ detail::addBytecodeInterface(this);
}
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+FloatType QuantizeCastOp::getFloatType() {
+ return cast<FloatType>(getPrimitiveType(getInput().getType()));
+}
+
+UniformQuantizedType QuantizeCastOp::getQuantizedType() {
+ return cast<UniformQuantizedType>(getPrimitiveType(getResult().getType()));
+}
+
+
+//===----------------------------------------------------------------------===//
+// StorageCastOp
+//===----------------------------------------------------------------------===//
+
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
// Matches x -> [scast -> scast] -> y, replacing the second scast with the
// value of x if the casts invert each other.
- auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
- if (!srcScastOp || srcScastOp.getArg().getType() != getType())
+ auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+ if (!srcScastOp || srcScastOp.getInput().getType() != getType())
return OpFoldResult();
- return srcScastOp.getArg();
+ return srcScastOp.getInput();
}
+} // namespace quant
+} // namespace mlir
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
+
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index a4829d472ecad..2038a86bec8d6 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -72,6 +72,17 @@ int64_t QuantizedType::getStorageTypeMax() const {
return static_cast<ImplType *>(impl)->storageTypeMax;
}
+bool QuantizedType::hasStorageTypeBounds() const {
+ unsigned int integralWidth = getStorageTypeIntegralWidth();
+ bool isSignedInteger = isSigned();
+ int64_t defaultIntegerMin =
+ getDefaultMinimumForInteger(isSignedInteger, integralWidth);
+ int64_t defaultIntegerMax =
+ getDefaultMaximumForInteger(isSignedInteger, integralWidth);
+ return defaultIntegerMin != getStorageTypeMin() ||
+ defaultIntegerMax != getStorageTypeMax();
+}
+
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
// NOTE: If ever supporting non-integral storage types, some other scheme
// for determining the width will be needed.
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index bf0f775146c1a..851763d8942e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -346,12 +346,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
}
// storageTypeMin and storageTypeMax if not default.
- int64_t defaultIntegerMin =
- QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
- int64_t defaultIntegerMax =
- QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
- if (defaultIntegerMin != type.getStorageTypeMin() ||
- defaultIntegerMax != type.getStorageTypeMax()) {
+ if (type.hasStorageTypeBounds()) {
out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
<< ">";
}
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 72f89326f555b..1067dc4f950d3 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -10,9 +10,16 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace quant {
@@ -22,13 +29,257 @@ namespace quant {
namespace {
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+public:
+ using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
+
+ // Lowering of 'quant.dcast' is not implemented yet. Report a match failure
+ // instead of returning success() without rewriting the op, which would
+ // claim a rewrite happened while leaving the illegal op in place.
+ LogicalResult
+ matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ return rewriter.notifyMatchFailure(op, "quant.dcast lowering not yet implemented");
+ }
+};
+
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+// If 'containerType' is a tensor, return its element type. If it is a scalar,
+// return it as is.
+Type getScalarType(Type containerType) {
+ if (auto tensorType = dyn_cast<TensorType>(containerType))
+ return tensorType.getElementType();
+ return containerType;
+}
+
+// Return the shape of a container as a combination of attributes (static
+// dimensions) and values (dynamic dimensions). If 'container' is a scalar,
+// an empty list is returned. If 'container' is a tensor, its shape is returned.
+SmallVector<OpFoldResult> getContainerShape(OpBuilder &builder, Location loc,
+ Value container) {
+ if (isa<TensorType>(container.getType()))
+ return tensor::getMixedSizes(builder, loc, container);
+ return {};
+}
+
+// Clone the given 'containerType' with the new given 'elementType'. If
+// 'containerType' is a scalar type, there is nothing to clone, and
+// 'elementType' itself is returned. If 'containerType' is a tensor, its
+// shape is cloned but the new element type is used.
+Type cloneContainerType(Type containerType, Type elementType) {
+ if (auto tensorType = dyn_cast<TensorType>(containerType))
+ return tensorType.clone(elementType);
+ return elementType;
+}
+
+// Get a scalar or tensor constant containing the value given in 'attr'.
+// If 'containerType' is a scalar, a scalar constant is returned. If
+// 'containerType' is a tensor, a tensor splat of shape 'containerShape' is
+// returned.
+Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
+ Type containerType,
+ ArrayRef<OpFoldResult> containerShape) {
+ // A statically shaped tensor can be created with 'arith.constant'
+ auto tensorType = dyn_cast<TensorType>(containerType);
+ if (tensorType && tensorType.hasStaticShape()) {
+ auto denseElementsAttr = DenseElementsAttr::get(tensorType, attr);
+ return builder.create<arith::ConstantOp>(loc, tensorType, denseElementsAttr);
+ }
+
+ // Scalar and dynamically shaped tensor containers need the scalar constant
+ // to be first materialized.
+ Value containerConstant =
+ builder.create<arith::ConstantOp>(loc, attr.getType(), attr);
+
+ // Create tensor splat if necessary
+ if (tensorType) {
+ containerConstant =
+ builder.create<tensor::SplatOp>(loc, containerConstant, containerShape);
+ }
+ return containerConstant;
+}
+
+// Calculate the size of an unranked tensor starting at dimension 'fromDim' up
+// to, but not including, dimension 'toDim'.
+Value getUnrankedTensorSizeRange(OpBuilder &builder, Location loc, Value input,
+ Value fromDim, Value toDim, Value one) {
+ auto loop = builder.create<scf::ForOp>(
+ loc,
+ fromDim, // lowerBound
+ toDim, // upperBound
+ one, // step
+ one, // iterArgs
+ [&](OpBuilder &builder, Location loc, Value index, ValueRange args) {
+ Value size = builder.create<tensor::DimOp>(loc, input, index);
+ Value totalSize = builder.create<arith::MulIOp>(loc, args.front(), size);
+ builder.create<scf::YieldOp>(loc, totalSize);
+ });
+ return loop.getResult(0);
+}
+
+// Obtain the shape of an unranked tensor. This function returns a 1D tensor of
+// size 'rank' and element type 'index'.
+Value getUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
+ Value rank) {
+ auto shapeType =
+ RankedTensorType::get({ShapedType::kDynamic}, builder.getIndexType());
+ auto shape = builder.create<tensor::GenerateOp>(
+ loc,
+ shapeType,
+ rank,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ Value size = builder.create<tensor::DimOp>(loc, input, args.front());
+ builder.create<tensor::YieldOp>(loc, size);
+ });
+ return shape;
+}
+
+class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
+
+ // Quantize a scalar or ranked tensor 'input' with the single scale and zero
+ // point of 'quantizedType' (per-layer quantization). The stored value is
+ // 'input / scale + zeroPoint' converted to the storage type, then clamped
+ // to the storage bounds when the type declares explicit ones.
+ Value convertPerLayerScalarOrRanked(
+ OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedType quantizedType) const {
+
+ auto inputType = input.getType();
+ auto expressedType = cast<FloatType>(quantizedType.getExpressedType());
+ auto storageType = cast<IntegerType>(quantizedType.getStorageType());
+ auto storageContainerType = cloneContainerType(inputType, storageType);
+
+ auto inputShape = getContainerShape(builder, loc, input);
+
+ // Scale and zero point, materialized with the same shape as the input
+ auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale());
+ auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape);
+ auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint());
+ auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape);
+
+ auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+ auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+
+ // NOTE(review): arith.fptosi/fptoui round toward zero rather than to the
+ // nearest integer; confirm this is the intended quantization rounding.
+ Value storedValue;
+ if (quantizedType.isSigned()) {
+ storedValue = builder.create<arith::FPToSIOp>(
+ loc, storageContainerType, storedValueAsExpressedType);
+ } else {
+ storedValue = builder.create<arith::FPToUIOp>(
+ loc, storageContainerType, storedValueAsExpressedType);
+ }
+
+ // Clamp stored value to the explicit storage bounds, if any. The bound
+ // constants are integers, so they must be shaped like the storage
+ // container, not like the float input (static tensor splats would
+ // otherwise get a mismatched element type).
+ if (quantizedType.hasStorageTypeBounds()) {
+ auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
+ auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
+ auto storageMin = getContainerConstant(builder, loc, storageMinAttr, storageContainerType, inputShape);
+ auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, storageContainerType, inputShape);
+ if (quantizedType.isSigned()) {
+ storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
+ storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
+ } else {
+ storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
+ storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
+ }
+ }
+
+ return storedValue;
+ }
+
+ // Quantize an unranked tensor 'input' by collapsing it into a 1D tensor,
+ // quantizing that as a ranked tensor, and reshaping the result back to the
+ // original (runtime) shape.
+ Value convertPerLayerUnranked(
+ OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedType quantizedType) const {
+ auto rank = builder.create<tensor::RankOp>(loc, input);
+ auto inputShape = getUnrankedTensorShape(builder, loc, input, rank);
+ auto inputType = cast<UnrankedTensorType>(input.getType());
+
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ auto inputSize = getUnrankedTensorSizeRange(builder, loc, input, zero, rank, one);
+
+ // Compute collapsed input shape as a 1D 1-sized index tensor
+ auto collapsedInputShapeType = RankedTensorType::get({1}, builder.getIndexType());
+ auto collapsedInputShape = builder.create<tensor::FromElementsOp>(
+ loc, collapsedInputShapeType, inputSize);
+
+ // Reshape input tensor into 1D
+ auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic},
+ inputType.getElementType());
+ auto collapsedInput = builder.create<tensor::ReshapeOp>(
+ loc, collapsedInputType, input, collapsedInputShape);
+
+ // Now we know how to convert a ranked tensor
+ auto collapsedStoredValue = convertPerLayerScalarOrRanked(
+ builder, loc, collapsedInput, quantizedType);
+
+ // Expand stored value back to the original shape
+ auto expandedStoredValueType =
+ UnrankedTensorType::get(quantizedType.getStorageType());
+ auto expandedStoredValue = builder.create<tensor::ReshapeOp>(
+ loc, expandedStoredValueType, collapsedStoredValue, inputShape);
+ return expandedStoredValue;
+ }
+
+public:
+ using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
+
+ // Replace 'quant.qcast' with arithmetic on the storage type followed by a
+ // 'quant.scast' back to the original quantized result type. Fails (rather
+ // than aborting) for result types this lowering does not handle yet.
+ LogicalResult
+ matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ // Take the operand from the adaptor, as conversion patterns should.
+ auto input = adaptor.getInput();
+ auto resultScalarType = getScalarType(op.getResult().getType());
+
+ // Per-layer vs per-channel quantization
+ Value storedValue;
+ if (auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType)) {
+ storedValue = isa<UnrankedTensorType>(input.getType()) ?
+ convertPerLayerUnranked(rewriter, loc, input, quantizedType) :
+ convertPerLayerScalarOrRanked(rewriter, loc, input, quantizedType);
+ } else if (isa<UniformQuantizedPerAxisType>(resultScalarType)) {
+ // FIXME: implement per-channel quantization. Until then, fail instead of
+ // building a 'quant.scast' from a null value below.
+ return rewriter.notifyMatchFailure(op, "per-channel quantization not yet supported");
+ } else {
+ // 'quant.qcast' results are not required to be quantized types, so this
+ // case is reachable; report it rather than invoking llvm_unreachable.
+ return rewriter.notifyMatchFailure(op, "expected a uniform quantized result type");
+ }
+
+ // Cast stored value to result quantized value
+ rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
+ op, op.getResult().getType(), storedValue);
+ return success();
+ }
+};
+
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
void runOnOperation() override {
- Operation *parentOp = getOperation();
+ RewritePatternSet patterns(&getContext());
+ populateLowerQuantOpsPatterns(patterns);
+
+ ConversionTarget target(getContext());
+ target.addLegalOp<quant::StorageCastOp>();
+ target.addIllegalDialect<quant::QuantDialect>();
+ target.addLegalDialect<
+ arith::ArithDialect,
+ linalg::LinalgDialect,
+ scf::SCFDialect,
+ tensor::TensorDialect
+ >();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
}
};
} // namespace
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
+ patterns.add<
+ DequantizeCastOpConversion,
+ QuantizeCastOpConversion
+ >(patterns.getContext());
+}
+
} // namespace quant
} // namespace mlir
>From 8f8f6de514a07452f60442464fdd8a77a166a9f4 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 11 Jul 2024 18:01:31 -0400
Subject: [PATCH 04/18] Custom formats for 'quant' ops and use of 'shape'
dialect ops for 'lower-quant-ops'
---
.../include/mlir/Dialect/Quant/IR/QuantOps.td | 13 +----
.../mlir/Dialect/Quant/Transforms/Passes.td | 2 +-
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 13 -----
.../Quant/Transforms/LowerQuantOps.cpp | 57 ++++---------------
4 files changed, 15 insertions(+), 70 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 48e2496203ff0..7a6d270dbb6e9 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -44,6 +44,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
}];
let arguments = (ins quant_RealValueType:$input);
let results = (outs quant_RealValueType:$result);
+ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
}
def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
@@ -72,16 +73,7 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
}];
let arguments = (ins quant_RealValueType:$input);
let results = (outs quant_RealValueType:$result);
-
- let extraClassDeclaration = [{
-
- /// Return the primitive (scalar or tensor element) type of the float input.
- FloatType getFloatType();
-
- /// Return the primitive (scalar or tensor element) type of the quantized
- /// result.
- quant::UniformQuantizedType getQuantizedType();
- }];
+ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
}
def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
@@ -107,6 +99,7 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
}];
let arguments = (ins quant_RealOrStorageValueType:$input);
let results = (outs quant_RealOrStorageValueType:$result);
+ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
let hasFolder = 1;
}
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index 19aa37f653c01..56e10688b0c98 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -26,7 +26,7 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
"arith::ArithDialect",
"linalg::LinalgDialect",
"quant::QuantDialect",
- "scf::SCFDialect",
+ "shape::ShapeDialect",
"tensor::TensorDialect"
];
}
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 2eafa348f6906..e04ca7eb7e715 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -52,19 +52,6 @@ void QuantDialect::initialize() {
}
-//===----------------------------------------------------------------------===//
-// QuantizeCastOp
-//===----------------------------------------------------------------------===//
-
-FloatType QuantizeCastOp::getFloatType() {
- return cast<FloatType>(getPrimitiveType(getInput().getType()));
-}
-
-UniformQuantizedType QuantizeCastOp::getQuantizedType() {
- return cast<UniformQuantizedType>(getPrimitiveType(getResult().getType()));
-}
-
-
//===----------------------------------------------------------------------===//
// StorageCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1067dc4f950d3..1daafdd715155 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -16,7 +16,7 @@
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -104,41 +104,6 @@ Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
return containerConstant;
}
-// Calculate the size of an unranked tensor starting at dimension 'fromDim' up
-// to, but not including, dimension 'toDim'.
-Value getUnrankedTensorSizeRange(OpBuilder &builder, Location loc, Value input,
- Value fromDim, Value toDim, Value one) {
- auto loop = builder.create<scf::ForOp>(
- loc,
- fromDim, // lowerBound
- toDim, // upperBound
- one, // step
- one, // iterArgs
- [&](OpBuilder &builder, Location loc, Value index, ValueRange args) {
- Value size = builder.create<tensor::DimOp>(loc, input, index);
- Value totalSize = builder.create<arith::MulIOp>(loc, args.front(), size);
- builder.create<scf::YieldOp>(loc, totalSize);
- });
- return loop.getResult(0);
-}
-
-// Obtain the shape of an unranked tensor. This function returns a 1D tensor of
-// size 'rank' and element type 'index'.
-Value getUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
- Value rank) {
- auto shapeType =
- RankedTensorType::get({ShapedType::kDynamic}, builder.getIndexType());
- auto shape = builder.create<tensor::GenerateOp>(
- loc,
- shapeType,
- rank,
- [&](OpBuilder &builder, Location loc, ValueRange args) {
- Value size = builder.create<tensor::DimOp>(loc, input, args.front());
- builder.create<tensor::YieldOp>(loc, size);
- });
- return shape;
-}
-
class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
Value convertPerLayerScalarOrRanked(
@@ -191,18 +156,18 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
Value convertPerLayerUnranked(
OpBuilder &builder, Location loc, Value input,
UniformQuantizedType quantizedType) const {
- auto rank = builder.create<tensor::RankOp>(loc, input);
- auto inputShape = getUnrankedTensorShape(builder, loc, input, rank);
+ auto *context = builder.getContext();
auto inputType = cast<UnrankedTensorType>(input.getType());
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
- auto inputSize = getUnrankedTensorSizeRange(builder, loc, input, zero, rank, one);
+ auto shapeType = shape::getExtentTensorType(context);
+ auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+ Value inputSize = builder.create<shape::NumElementsOp>(
+ loc, builder.getIndexType(), inputShape);
- // Compute collapsed input shape as a 1D 1-sized index tensor
- auto collapsedInputShapeType = RankedTensorType::get({1}, builder.getIndexType());
+ // Turn input size into 1D tensor
+ auto collapsedShapeType = shape::getExtentTensorType(context, 1);
auto collapsedInputShape = builder.create<tensor::FromElementsOp>(
- loc, collapsedInputShapeType, inputSize);
+ loc, collapsedShapeType, inputSize);
// Reshape input tensor into 1D
auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic},
@@ -210,7 +175,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
auto collapsedInput = builder.create<tensor::ReshapeOp>(
loc, collapsedInputType, input, collapsedInputShape);
- // Now we know how to convert a ranked tensor
+ // We now know how to deal with a 1D ranked input
auto collapsedStoredValue = convertPerLayerScalarOrRanked(
builder, loc, collapsedInput, quantizedType);
@@ -262,7 +227,7 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
target.addLegalDialect<
arith::ArithDialect,
linalg::LinalgDialect,
- scf::SCFDialect,
+ shape::ShapeDialect,
tensor::TensorDialect
>();
>From ba88a9c38a212637c3446aca204d7d8f767fad33 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 15 Jul 2024 18:08:50 -0400
Subject: [PATCH 05/18] Progress in per-channel quantization
---
.../Quant/Transforms/LowerQuantOps.cpp | 237 ++++++++++++------
1 file changed, 167 insertions(+), 70 deletions(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1daafdd715155..f5f4de807c8f0 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -104,87 +104,186 @@ Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
return containerConstant;
}
+std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
+ Value input) {
+ // Get unranked input shape and total size
+ auto *context = builder.getContext();
+ auto shapeType = shape::getExtentTensorType(context);
+ auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+ Value inputSize = builder.create<shape::NumElementsOp>(
+ loc, builder.getIndexType(), inputShape);
+
+ // Turn input size into 1D tensor
+ auto flatShapeType = shape::getExtentTensorType(context, 1);
+ auto flatInputShape = builder.create<tensor::FromElementsOp>(
+ loc, flatShapeType, inputSize);
+
+ // Reshape input tensor into 1D
+ auto inputType = cast<UnrankedTensorType>(input.getType());
+ auto flatInputType =
+ RankedTensorType::get({ShapedType::kDynamic}, inputType.getElementType());
+ auto flatInput = builder.create<tensor::ReshapeOp>(
+ loc, flatInputType, input, flatInputShape);
+ return std::make_pair(flatInput, inputShape);
+}
+
+Value restoreUnrankedTensor(OpBuilder &builder, Location loc, Value input,
+ Value shape) {
+ auto inputType = cast<TensorType>(input.getType());
+ auto elementType = inputType.getElementType();
+ auto unrankedType = UnrankedTensorType::get(elementType);
+ return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, shape);
+}
+
+Value materializeScales(OpBuilder &builder, Location loc,
+ UniformQuantizedPerAxisType quantizedType) {
+ auto scales = quantizedType.getScales();
+ auto expressedType = quantizedType.getExpressedType();
+ auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
+ return builder.getFloatAttr(expressedType, scale);
+ });
+ auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
+ auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
+ return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
+}
+
+Value materializeZeroPoints(OpBuilder &builder, Location loc,
+ UniformQuantizedPerAxisType quantizedType) {
+ auto zeroPoints = quantizedType.getZeroPoints();
+ auto expressedType = quantizedType.getExpressedType();
+ auto zeroPointAttrs = llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
+ return builder.getFloatAttr(expressedType, static_cast<double>(zeroPoint));
+ });
+ auto tensorType = RankedTensorType::get({(int64_t) zeroPoints.size()}, expressedType);
+ auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
+ return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
+}
+
+Value quantizeValue(OpBuilder &builder, Location loc, Value input,
+ ArrayRef<OpFoldResult> inputShape, Value scale,
+ Value zeroPoint, QuantizedType quantizedType) {
+ auto inputType = input.getType();
+ auto storageType = cast<IntegerType>(quantizedType.getStorageType());
+ auto storageContainerType = cloneContainerType(inputType, storageType);
+
+ auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+ auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+
+ Value storedValue;
+ if (quantizedType.isSigned()) {
+ storedValue = builder.create<arith::FPToSIOp>(
+ loc, storageContainerType, storedValueAsExpressedType);
+ } else {
+ storedValue = builder.create<arith::FPToUIOp>(
+ loc, storageContainerType, storedValueAsExpressedType);
+ }
+
+ // Clamp stored value if needed
+ if (quantizedType.hasStorageTypeBounds()) {
+ auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
+ auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
+ auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape);
+ auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape);
+ if (quantizedType.isSigned()) {
+ storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
+ storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
+ } else {
+ storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
+ storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
+ }
+ }
+
+ return storedValue;
+}
+
class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
- Value convertPerLayerScalarOrRanked(
- OpBuilder &builder, Location loc, Value input,
- UniformQuantizedType quantizedType) const {
+ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedType quantizedType) const {
auto inputType = input.getType();
auto expressedType = cast<FloatType>(quantizedType.getExpressedType());
- auto storageType = cast<IntegerType>(quantizedType.getStorageType());
- auto storageContainerType = cloneContainerType(inputType, storageType);
+ // Create scale and zero point constants
auto inputShape = getContainerShape(builder, loc, input);
-
- // Scale and zero point scalars
auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale());
auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape);
auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint());
auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape);
- auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
- auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+ return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+ quantizedType);
+ }
- Value storedValue;
- if (quantizedType.isSigned()) {
- storedValue = builder.create<arith::FPToSIOp>(
- loc, storageContainerType, storedValueAsExpressedType);
- } else {
- storedValue = builder.create<arith::FPToUIOp>(
- loc, storageContainerType, storedValueAsExpressedType);
- }
+ Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedType quantizedType) const {
+ // Flatten input if unranked
+ bool isUnranked = isa<UnrankedTensorType>(input.getType());
+ Value shape;
+ if (isUnranked)
+ std::tie(input, shape) = flattenUnrankedTensor(builder, loc, input);
- // Clamp stored value if needed
- if (quantizedType.hasStorageTypeBounds()) {
- auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
- auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
- auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape);
- auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape);
- if (quantizedType.isSigned()) {
- storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
- storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
- } else {
- storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
- storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
- }
- }
+ // Process ranked tensor
+ auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
- return storedValue;
+ // Restore original shape if unranked
+ if (isUnranked)
+ result = restoreUnrankedTensor(builder, loc, result, shape);
+
+ return result;
}
-
- Value convertPerLayerUnranked(
- OpBuilder &builder, Location loc, Value input,
- UniformQuantizedType quantizedType) const {
+
+ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedPerAxisType quantizedType,
+ int32_t channelAxis) const {
auto *context = builder.getContext();
- auto inputType = cast<UnrankedTensorType>(input.getType());
-
- auto shapeType = shape::getExtentTensorType(context);
- auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
- Value inputSize = builder.create<shape::NumElementsOp>(
- loc, builder.getIndexType(), inputShape);
-
- // Turn input size into 1D tensor
- auto collapsedShapeType = shape::getExtentTensorType(context, 1);
- auto collapsedInputShape = builder.create<tensor::FromElementsOp>(
- loc, collapsedShapeType, inputSize);
-
- // Reshape input tensor into 1D
- auto collapsedInputType = RankedTensorType::get({ShapedType::kDynamic},
- inputType.getElementType());
- auto collapsedInput = builder.create<tensor::ReshapeOp>(
- loc, collapsedInputType, input, collapsedInputShape);
-
- // We now know how to deal with a 1D ranked input
- auto collapsedStoredValue = convertPerLayerScalarOrRanked(
- builder, loc, collapsedInput, quantizedType);
-
- // Expand stored value back to the original shape
- auto expandedStoredValueType =
- UnrankedTensorType::get(quantizedType.getStorageType());
- auto expandedStoredValue = builder.create<tensor::ReshapeOp>(
- loc, expandedStoredValueType, collapsedStoredValue, inputShape);
- return expandedStoredValue;
+
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto inputRank = inputType.getRank();
+
+ auto scales = materializeScales(builder, loc, quantizedType);
+ auto zeroPoints = materializeZeroPoints(builder, loc, quantizedType);
+
+ auto storageType = quantizedType.getStorageType();
+ auto initShape = tensor::getMixedSizes(builder, loc, input);
+ Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
+
+ SmallVector<utils::IteratorType> iteratorTypes(
+ inputRank, utils::IteratorType::parallel);
+ auto channelAxisAffineMap = AffineMap::get(
+ inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+ SmallVector<AffineMap> indexingMaps{
+ builder.getMultiDimIdentityMap(inputRank),
+ channelAxisAffineMap,
+ channelAxisAffineMap,
+ builder.getMultiDimIdentityMap(inputRank)
+ };
+ auto storedValue = builder.create<linalg::GenericOp>(
+ loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps,
+ iteratorTypes,
+ [&](OpBuilder& builder, Location loc, ValueRange args) {
+ assert(args.size() == 4);
+ auto expressedValue = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ auto storedValue = quantizeValue(builder, loc, expressedValue, {},
+ scale, zeroPoint, quantizedType);
+
+ builder.create<linalg::YieldOp>(loc, storedValue);
+ })
+ .getResult(0);
+
+ return storedValue;
+ }
+
+ Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedPerAxisType quantizedType) const {
+ return convertPerChannelRanked(builder, loc, input, quantizedType, quantizedType.getQuantizedDimension());
}
public:
@@ -197,18 +296,16 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
auto input = op.getInput();
auto resultScalarType = getScalarType(op.getResult().getType());
- // Per-layer vs per-channel quantization
+ // Per-layer vs per-channel quantization
Value storedValue;
if (auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType)) {
- storedValue = isa<UnrankedTensorType>(input.getType()) ?
- convertPerLayerUnranked(rewriter, loc, input, quantizedType) :
- convertPerLayerScalarOrRanked(rewriter, loc, input, quantizedType);
+ storedValue = convertPerLayer(rewriter, loc, input, quantizedType);
} else if (auto quantizedType = dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
- // FIXM
+ storedValue = convertPerChannel(rewriter, loc, input, quantizedType);
} else {
llvm_unreachable("unexpected quantized type");
}
-
+
// Cast stored value to result quantized value
rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
op, op.getResult().getType(), storedValue);
>From 6b9caccad131a0d2e6cf22419b0486cf01f1b28b Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 16 Jul 2024 10:15:24 -0400
Subject: [PATCH 06/18] Support for unranked tensor in per-channel quantization
---
.../Quant/Transforms/LowerQuantOps.cpp | 72 +++++++++++++++++--
1 file changed, 66 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index f5f4de807c8f0..49ef8ea354070 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -120,10 +120,52 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
// Reshape input tensor into 1D
auto inputType = cast<UnrankedTensorType>(input.getType());
+ auto elementType = inputType.getElementType();
auto flatInputType =
- RankedTensorType::get({ShapedType::kDynamic}, inputType.getElementType());
+ RankedTensorType::get({ShapedType::kDynamic}, elementType);
+ auto flatInput = builder.create<tensor::ReshapeOp>(
+ loc, flatInputType, input, flatInputShape);
+ return std::make_pair(flatInput, inputShape);
+}
+
+std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
+ Location loc,
+ Value input,
+ int64_t axis) {
+ // Get full tensor shape
+ auto *context = builder.getContext();
+ auto indexType = builder.getIndexType();
+ auto shapeType = shape::getExtentTensorType(context);
+ auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+
+ // Get shape and sizes on left and right of axis
+ auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
+ auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
+ auto shapeLeft = builder.create<shape::SplitAtOp>(
+ loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
+ .getResult(0);
+ auto sizeLeft = builder.create<shape::NumElementsOp>(
+ loc, indexType, shapeLeft);
+ auto shapeRight = builder.create<shape::SplitAtOp>(
+ loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
+ .getResult(1);
+ auto sizeRight = builder.create<shape::NumElementsOp>(
+ loc, indexType, shapeRight);
+ Value axisSize = builder.create<tensor::DimOp>(loc, input, axisValue);
+
+ // Compute flat input shape as a 3-element 1D tensor
+ auto flatShapeType = shape::getExtentTensorType(context, 3);
+ auto flatInputShape = builder.create<tensor::FromElementsOp>(
+ loc, flatShapeType, ValueRange{sizeLeft, axisSize, sizeRight});
+
+ // Reshape input to 3D tensor
+ auto inputType = cast<UnrankedTensorType>(input.getType());
+ auto elementType = inputType.getElementType();
+ SmallVector<int64_t> flatInputDims(3, ShapedType::kDynamic);
+ auto flatInputType = RankedTensorType::get(flatInputDims, elementType);
auto flatInput = builder.create<tensor::ReshapeOp>(
loc, flatInputType, input, flatInputShape);
+
return std::make_pair(flatInput, inputShape);
}
@@ -219,23 +261,23 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
UniformQuantizedType quantizedType) const {
// Flatten input if unranked
bool isUnranked = isa<UnrankedTensorType>(input.getType());
- Value shape;
+ Value inputShape;
if (isUnranked)
- std::tie(input, shape) = flattenUnrankedTensor(builder, loc, input);
+ std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
// Process ranked tensor
auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
// Restore original shape if unranked
if (isUnranked)
- result = restoreUnrankedTensor(builder, loc, result, shape);
+ result = restoreUnrankedTensor(builder, loc, result, inputShape);
return result;
}
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
UniformQuantizedPerAxisType quantizedType,
- int32_t channelAxis) const {
+ int64_t channelAxis) const {
auto *context = builder.getContext();
auto inputType = cast<RankedTensorType>(input.getType());
@@ -283,7 +325,25 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
UniformQuantizedPerAxisType quantizedType) const {
- return convertPerChannelRanked(builder, loc, input, quantizedType, quantizedType.getQuantizedDimension());
+ // Flatten unranked tensor if necessary
+ bool isUnranked = isa<UnrankedTensorType>(input.getType());
+ int64_t channelAxis = quantizedType.getQuantizedDimension();
+ Value inputShape;
+ if (isUnranked) {
+ std::tie(input, inputShape) =
+ flattenUnrankedTensorAroundAxis(builder, loc, input, channelAxis);
+ channelAxis = 1;
+ }
+
+ // Work on a ranked tensor
+ auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
+ channelAxis);
+
+ // Restore original tensor shape if unranked
+ if (isUnranked)
+ result = restoreUnrankedTensor(builder, loc, result, inputShape);
+
+ return result;
}
public:
>From b2fee689c74012651662d4edc8c00988aea5344b Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 16 Jul 2024 15:24:15 -0400
Subject: [PATCH 07/18] Refactored 'quant.qcast' lowering. Ready to begin
'quant.dcast'
---
.../Quant/Transforms/LowerQuantOps.cpp | 270 ++++++++++--------
1 file changed, 158 insertions(+), 112 deletions(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 49ef8ea354070..68a4328128292 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -49,59 +49,50 @@ class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeC
// QuantizeCastOp
//===----------------------------------------------------------------------===//
-// If 'containerType' is a tensor, return its element type. If it is a scalar,
+// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
-Type getScalarType(Type containerType) {
- if (auto tensorType = dyn_cast<TensorType>(containerType))
+Type getScalarType(Type inputType) {
+ if (auto tensorType = dyn_cast<TensorType>(inputType))
return tensorType.getElementType();
- return containerType;
+ return inputType;
}
-// Return the shape of a container as a combination of attributes (static
-// dimensions) and values (dynamic dimensions). If 'container' is a scalar,
-// an empty list is returned. If 'container' is a tensor, its shape is returned.
-SmallVector<OpFoldResult> getContainerShape(OpBuilder &builder, Location loc,
- Value container) {
- if (isa<TensorType>(container.getType()))
- return tensor::getMixedSizes(builder, loc, container);
+// Return the shape of an input value as a list of attributes (static dimensions)
+// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
+// returned. If 'input' is a tensor, its shape is returned.
+SmallVector<OpFoldResult>
+getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+ if (isa<TensorType>(input.getType()))
+ return tensor::getMixedSizes(builder, loc, input);
return {};
}
-// Clone the given 'containerType' with the new given 'elementType'. If
-// 'containerType' is a scalar type, there is nothing to clone, and
-// 'elementType' itself is returned. If 'constainerType' is a tensor, its
-// shape is cloned but the new element type is used.
-Type cloneContainerType(Type containerType, Type elementType) {
- if (auto tensorType = dyn_cast<TensorType>(containerType))
+// If 'referenceType' is a scalar, return 'elementType' as is. If
+// 'referenceType' is a tensor, return another tensor with the same shape and
+// elements of type 'elementType'.
+Type getScalarOrTensorType(Type elementType, Type referenceType) {
+ if (auto tensorType = dyn_cast<TensorType>(referenceType))
return tensorType.clone(elementType);
return elementType;
}
-// Get a scalar or tensor constant containing the value given in 'attr'.
-// If 'containerType' is a scalar, a scalar constant is returned. If
-// 'containerType' is a tensor, a tensor splat of shape 'containerShape' is
-// returned.
-Value getContainerConstant(OpBuilder &builder, Location loc, TypedAttr attr,
- Type containerType,
- ArrayRef<OpFoldResult> containerShape) {
- // A statically shaped tensor can be created with 'arith.constant'
- auto tensorType = dyn_cast<TensorType>(containerType);
- if (tensorType && tensorType.hasStaticShape()) {
- auto denseElementsAttr = DenseElementsAttr::get(tensorType, attr);
- return builder.create<arith::ConstantOp>(loc, tensorType, denseElementsAttr);
+// Broadcast 'scalar' to match 'referenceType'. If 'referenceType' is a tensor,
+// a tensor splat of 'scalar' with shape 'referenceShape' is returned. If it is
+// a scalar type, 'referenceShape' must be empty and 'scalar' is returned as is.
+Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
+ Type referenceType,
+ ArrayRef<OpFoldResult> referenceShape) {
+ // If the result type is a scalar, return the unmodified scalar constant.
+ auto tensorType = dyn_cast<TensorType>(referenceType);
+ if (!tensorType) {
+ assert(referenceShape.empty());
+ return scalar;
}
- // Scalar and dynamically shaped tensor containers need the scalar constant
- // to be first materialized.
- Value containerConstant =
- builder.create<arith::ConstantOp>(loc, attr.getType(), attr);
-
- // Create tensor splat if necessary
- if (tensorType) {
- containerConstant =
- builder.create<tensor::SplatOp>(loc, containerConstant, containerShape);
- }
- return containerConstant;
+ // Create tensor splat
+ auto tensorConstant =
+ builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
+ return tensorConstant;
}
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
@@ -131,7 +122,8 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
Location loc,
Value input,
- int64_t axis) {
+ int64_t axis,
+ int64_t axisSize) {
// Get full tensor shape
auto *context = builder.getContext();
auto indexType = builder.getIndexType();
@@ -151,34 +143,34 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
.getResult(1);
auto sizeRight = builder.create<shape::NumElementsOp>(
loc, indexType, shapeRight);
- Value axisSize = builder.create<tensor::DimOp>(loc, input, axisValue);
// Compute flat input shape as a 3-element 1D tensor
+ auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
auto flatShapeType = shape::getExtentTensorType(context, 3);
auto flatInputShape = builder.create<tensor::FromElementsOp>(
- loc, flatShapeType, ValueRange{sizeLeft, axisSize, sizeRight});
+ loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
// Reshape input to 3D tensor
auto inputType = cast<UnrankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
- SmallVector<int64_t> flatInputDims(3, ShapedType::kDynamic);
- auto flatInputType = RankedTensorType::get(flatInputDims, elementType);
+ auto flatInputType = RankedTensorType::get(
+ {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
auto flatInput = builder.create<tensor::ReshapeOp>(
loc, flatInputType, input, flatInputShape);
return std::make_pair(flatInput, inputShape);
}
-Value restoreUnrankedTensor(OpBuilder &builder, Location loc, Value input,
- Value shape) {
- auto inputType = cast<TensorType>(input.getType());
+Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
+ Value inputShape) {
+ auto inputType = cast<RankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto unrankedType = UnrankedTensorType::get(elementType);
- return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, shape);
+ return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
}
-Value materializeScales(OpBuilder &builder, Location loc,
- UniformQuantizedPerAxisType quantizedType) {
+Value materializePerChannelScales(OpBuilder &builder, Location loc,
+ UniformQuantizedPerAxisType quantizedType) {
auto scales = quantizedType.getScales();
auto expressedType = quantizedType.getExpressedType();
auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
@@ -189,52 +181,100 @@ Value materializeScales(OpBuilder &builder, Location loc,
return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}
-Value materializeZeroPoints(OpBuilder &builder, Location loc,
+Value materializePerChannelZeroPoints(OpBuilder &builder, Location loc,
UniformQuantizedPerAxisType quantizedType) {
auto zeroPoints = quantizedType.getZeroPoints();
- auto expressedType = quantizedType.getExpressedType();
- auto zeroPointAttrs = llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
- return builder.getFloatAttr(expressedType, static_cast<double>(zeroPoint));
- });
- auto tensorType = RankedTensorType::get({(int64_t) zeroPoints.size()}, expressedType);
+ auto storageType = quantizedType.getStorageType();
+ auto zeroPointAttrs = llvm::map_to_vector(
+ zeroPoints,
+ [&](int64_t zeroPoint) -> Attribute {
+ return builder.getIntegerAttr(storageType, zeroPoint);
+ });
+ auto tensorType =
+ RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}
-Value quantizeValue(OpBuilder &builder, Location loc, Value input,
- ArrayRef<OpFoldResult> inputShape, Value scale,
- Value zeroPoint, QuantizedType quantizedType) {
- auto inputType = input.getType();
- auto storageType = cast<IntegerType>(quantizedType.getStorageType());
- auto storageContainerType = cloneContainerType(inputType, storageType);
-
- auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
- auto storedValueAsExpressedType = builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+ ArrayRef<OpFoldResult> inputShape,
+ QuantizedType quantizedType) {
+ // If quantized type does not narrow down the storage type range, there is
+ // nothing to do.
+ if (!quantizedType.hasStorageTypeBounds())
+ return input;
- Value storedValue;
+ // Materialize bounds
+ auto inputType = input.getType();
+ auto storageType = quantizedType.getStorageType();
+ auto storageMinScalar = builder.create<arith::ConstantIntOp>(
+ loc, quantizedType.getStorageTypeMin(), storageType);
+ auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
+ loc, quantizedType.getStorageTypeMax(), storageType);
+ auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
+ inputType, inputShape);
+ auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
+ inputType, inputShape);
+
+ // Clamp
if (quantizedType.isSigned()) {
- storedValue = builder.create<arith::FPToSIOp>(
- loc, storageContainerType, storedValueAsExpressedType);
+ input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
+ input = builder.create<arith::MinSIOp>(loc, input, storageMax);
} else {
- storedValue = builder.create<arith::FPToUIOp>(
- loc, storageContainerType, storedValueAsExpressedType);
+ input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
+ input = builder.create<arith::MinUIOp>(loc, input, storageMax);
}
+ return input;
+}
- // Clamp stored value if needed
- if (quantizedType.hasStorageTypeBounds()) {
- auto storageMinAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMin());
- auto storageMaxAttr = builder.getIntegerAttr(storageType, quantizedType.getStorageTypeMax());
- auto storageMin = getContainerConstant(builder, loc, storageMinAttr, inputType, inputShape);
- auto storageMax = getContainerConstant(builder, loc, storageMaxAttr, inputType, inputShape);
- if (quantizedType.isSigned()) {
- storedValue = builder.create<arith::MaxSIOp>(loc, storedValue, storageMin);
- storedValue = builder.create<arith::MinSIOp>(loc, storedValue, storageMax);
- } else {
- storedValue = builder.create<arith::MaxUIOp>(loc, storedValue, storageMin);
- storedValue = builder.create<arith::MinUIOp>(loc, storedValue, storageMax);
- }
- }
+Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
+ Type resultType, bool isSigned) {
+ if (isSigned)
+ return builder.create<arith::FPToSIOp>(loc, resultType, input);
+ return builder.create<arith::FPToUIOp>(loc, resultType, input);
+}
+
+Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
+ Type resultType, bool isSigned) {
+ if (isSigned)
+ return builder.create<arith::SIToFPOp>(loc, resultType, input);
+ return builder.create<arith::UIToFPOp>(loc, resultType, input);
+}
+// Quantize a floating-point input using the given scale and zero point, and
+// clamp the result to the storage type bounds of the given quantized type.
+Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+ ArrayRef<OpFoldResult> inputShape, Value scale,
+ Value zeroPoint, QuantizedType quantizedType) {
+ // Convert scale and zero point to tensors if necessary
+ auto inputType = input.getType();
+ scale = getScalarOrTensorConstant(
+ builder, loc, scale, inputType, inputShape);
+ zeroPoint = getScalarOrTensorConstant(
+ builder, loc, zeroPoint, inputType, inputShape);
+
+ // Convert zero point from storage to expressed type
+ auto expressedScalarOrTensorType =
+ getScalarOrTensorType(quantizedType.getExpressedType(), inputType);
+ zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+ expressedScalarOrTensorType,
+ quantizedType.isSigned());
+
+ // Scale input and add zero point
+ auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+ auto storedValueAsExpressedType =
+ builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+
+ // Convert to storage type
+ auto storageScalarOrTensorType =
+ getScalarOrTensorType(quantizedType.getStorageType(), inputType);
+ auto storedValue = convertFloatToInteger(
+ builder, loc, storedValueAsExpressedType, storageScalarOrTensorType,
+ quantizedType.isSigned());
+
+ // Clamp stored value if the storage type is bounded
+ storedValue =
+ clampScalarOrTensor(builder, loc, storedValue, inputShape, quantizedType);
return storedValue;
}
@@ -243,18 +283,21 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
UniformQuantizedType quantizedType) const {
- auto inputType = input.getType();
- auto expressedType = cast<FloatType>(quantizedType.getExpressedType());
-
// Create scale and zero point constants
- auto inputShape = getContainerShape(builder, loc, input);
- auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale());
- auto scale = getContainerConstant(builder, loc, scaleAttr, inputType, inputShape);
- auto zeroPointAttr = builder.getFloatAttr(expressedType, quantizedType.getZeroPoint());
- auto zeroPoint = getContainerConstant(builder, loc, zeroPointAttr, inputType, inputShape);
-
- return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
- quantizedType);
+ auto expressedType = quantizedType.getExpressedType();
+ auto storageType = quantizedType.getStorageType();
+ auto scaleAttr =
+ builder.getFloatAttr(expressedType, quantizedType.getScale());
+ auto scale =
+ builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
+ auto zeroPointAttr =
+ builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
+ auto zeroPoint =
+ builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
+
+ auto inputShape = getScalarOrTensorShape(builder, loc, input);
+ return quantizeScalarOrTensor(builder, loc, input, inputShape, scale,
+ zeroPoint, quantizedType);
}
Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
@@ -270,7 +313,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
// Restore original shape if unranked
if (isUnranked)
- result = restoreUnrankedTensor(builder, loc, result, inputShape);
+ result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
return result;
}
@@ -283,8 +326,9 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
- auto scales = materializeScales(builder, loc, quantizedType);
- auto zeroPoints = materializeZeroPoints(builder, loc, quantizedType);
+ auto scales = materializePerChannelScales(builder, loc, quantizedType);
+ auto zeroPoints =
+ materializePerChannelZeroPoints(builder, loc, quantizedType);
auto storageType = quantizedType.getStorageType();
auto initShape = tensor::getMixedSizes(builder, loc, input);
@@ -313,10 +357,10 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
auto scale = args[1];
auto zeroPoint = args[2];
- auto storedValue = quantizeValue(builder, loc, expressedValue, {},
- scale, zeroPoint, quantizedType);
+ auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
+ scale, zeroPoint, quantizedType);
- builder.create<linalg::YieldOp>(loc, storedValue);
+ builder.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
@@ -325,13 +369,14 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
UniformQuantizedPerAxisType quantizedType) const {
- // Flatten unranked tensor if necessary
+ // Flatten unranked tensor into a 3D ranked tensor if necessary
bool isUnranked = isa<UnrankedTensorType>(input.getType());
int64_t channelAxis = quantizedType.getQuantizedDimension();
+ int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
Value inputShape;
if (isUnranked) {
- std::tie(input, inputShape) =
- flattenUnrankedTensorAroundAxis(builder, loc, input, channelAxis);
+ std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+ builder, loc, input, channelAxis, channelAxisSize);
channelAxis = 1;
}
@@ -341,7 +386,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
// Restore original tensor shape if unranked
if (isUnranked)
- result = restoreUnrankedTensor(builder, loc, result, inputShape);
+ result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
return result;
}
@@ -357,18 +402,19 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
auto resultScalarType = getScalarType(op.getResult().getType());
// Flatten unranked tensor input
- Value storedValue;
+ Value result;
if (auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType)) {
- storedValue = convertPerLayer(rewriter, loc, input, quantizedType);
- } else if (auto quantizedType = dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
- storedValue = convertPerChannel(rewriter, loc, input, quantizedType);
+ result = convertPerLayer(rewriter, loc, input, quantizedType);
+ } else if (auto quantizedType =
+ dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
+ result = convertPerChannel(rewriter, loc, input, quantizedType);
} else {
- llvm_unreachable("unexpected quantized type");
+ llvm_unreachable("unexpected uniform quantized type");
}
// Cast stored value to result quantized value
rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
- op, op.getResult().getType(), storedValue);
+ op, op.getResult().getType(), result);
return success();
}
};
>From f032c472134642dc44ab6134d2e2646ca13ca341 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 17 Jul 2024 11:31:20 -0400
Subject: [PATCH 08/18] Per-layer quantization unit tests
---
.../Quant/Transforms/LowerQuantOps.cpp | 50 +++---
mlir/test/Dialect/Quant/lower-quant-ops.mlir | 165 ++++++++++++++++++
2 files changed, 194 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Dialect/Quant/lower-quant-ops.mlir
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 68a4328128292..e6942899cd638 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -181,8 +182,9 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}
-Value materializePerChannelZeroPoints(OpBuilder &builder, Location loc,
- UniformQuantizedPerAxisType quantizedType) {
+Value materializePerChannelZeroPoints(
+ OpBuilder &builder, Location loc,
+ UniformQuantizedPerAxisType quantizedType) {
auto zeroPoints = quantizedType.getZeroPoints();
auto storageType = quantizedType.getStorageType();
auto zeroPointAttrs = llvm::map_to_vector(
@@ -246,36 +248,42 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape, Value scale,
Value zeroPoint, QuantizedType quantizedType) {
- // Convert scale and zero point to tensors if necessary
+ // Convert scale to tensor if necessary
auto inputType = input.getType();
scale = getScalarOrTensorConstant(
builder, loc, scale, inputType, inputShape);
- zeroPoint = getScalarOrTensorConstant(
- builder, loc, zeroPoint, inputType, inputShape);
- // Convert zero point from storage to expressed type
- auto expressedScalarOrTensorType =
- getScalarOrTensorType(quantizedType.getExpressedType(), inputType);
- zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
- expressedScalarOrTensorType,
- quantizedType.isSigned());
-
- // Scale input and add zero point
+ // Scale input
auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
- auto storedValueAsExpressedType =
- builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
- // Convert to storage type
+ // Skip unnecessary computations if no zero point is given
+ Value storedValueFloat = scaledValue;
+ if (matchPattern(zeroPoint, m_NonZero())) {
+ // Convert zero point to tensor if necessary
+ zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+ inputShape);
+
+ // Convert zero point from storage to expressed type
+ zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+ scale.getType(),
+ quantizedType.isSigned());
+
+ // Add zero point to stored value
+ storedValueFloat =
+ builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+ }
+
+ // Convert stored value to storage type
auto storageScalarOrTensorType =
getScalarOrTensorType(quantizedType.getStorageType(), inputType);
- auto storedValue = convertFloatToInteger(
- builder, loc, storedValueAsExpressedType, storageScalarOrTensorType,
+ auto storedValueInt = convertFloatToInteger(
+ builder, loc, storedValueFloat, storageScalarOrTensorType,
quantizedType.isSigned());
// Clamp stored value it if the storage type is bound
- storedValue =
- clampScalarOrTensor(builder, loc, storedValue, inputShape, quantizedType);
- return storedValue;
+ auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
+ inputShape, quantizedType);
+ return storedValueClamped;
}
class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
new file mode 100644
index 0000000000000..0e151c514eebc
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @qcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : f32 to i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : i8 to !quant.uniform<i8:f32, 2.000000e+00:10>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8:f32, 2.000000e+00:10>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_scalar(%arg0: f32) -> !qalias {
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_scalar_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_NEG_5:.*]] = arith.constant -5 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_5]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+
+!qalias = !quant.uniform<i8<-5:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_bounds(%arg0: f32) -> !qalias {
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_scalar_unsigned_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptoui %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxui %[[STORED_INT]], %[[C_2]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minui %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+
+!qalias = !quant.uniform<u8<2:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_FLOAT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<3x?x5xf32> to tensor<3x?x5x!qalias>
+ return %0 : tensor<3x?x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]] : tensor<3x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_SPLAT]] : tensor<3x5xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<3x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<3x5xi8> to tensor<3x5xf32>
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<3x5xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<3x5xf32> to tensor<3x5xi8>
+
+// CHECK-DAG: %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK-DAG: %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK-DAG: %[[SPLAT_NEG_8:.*]] = tensor.splat %[[C_NEG_8]] : tensor<3x5xi8>
+// CHECK-DAG: %[[SPLAT_7:.*]] = tensor.splat %[[C_7]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[SPLAT_NEG_8]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[SPLAT_7]] : tensor<3x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : tensor<3x5xi8> to tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8<-8:7>:f32, 2.0:10>
+func.func @qcast_per_layer_ranked_bounds(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias>
+ return %0 : tensor<3x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[SIZE:.*]] = shape.num_elements %[[SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
+// CHECK: %[[RANKED_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[RANKED_INPUT]], %[[C_0]] : tensor<?xf32>
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[RANKED_INPUT]], %[[SCALE_SPLAT]] : tensor<?xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<?xf32> to tensor<?xi8>
+
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[STORED_INT]](%[[SHAPE]]) : (tensor<?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+ return %0 : tensor<*x!qalias>
+}
+
>From 3a913979546c071eaae5e681d4313e5abf18e523 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 17 Jul 2024 12:45:58 -0400
Subject: [PATCH 09/18] Completed unit tests for 'quant.qcast'
---
.../Quant/Transforms/LowerQuantOps.cpp | 2 +-
mlir/test/Dialect/Quant/lower-quant-ops.mlir | 117 ++++++++++++++++++
2 files changed, 118 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index e6942899cd638..1ffbd7032ae58 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -258,7 +258,7 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
// Skip unnecessary computations if no zero point is given
Value storedValueFloat = scaledValue;
- if (matchPattern(zeroPoint, m_NonZero())) {
+ if (!matchPattern(zeroPoint, m_Zero())) {
// Convert zero point to tensor if necessary
zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
inputShape);
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 0e151c514eebc..1aa14c615c6e5 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -163,3 +163,120 @@ func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
return %0 : tensor<*x!qalias>
}
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x?x?x5xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<4x?x?x5xf32>
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<4x?x?x5xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<4x?x?x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x?x?x5xi8> to tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> {
+ %0 = "quant.qcast"(%arg0) : (tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias>
+ return %0 : tensor<4x?x?x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x2x5xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<0> : tensor<2xi8>
+
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x5xi8>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x2x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK: %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_8]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_7]] : i8
+// CHECK: linalg.yield %[[STORED_CLAMPED]] : i8
+// CHECK: } -> tensor<4x2x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x2x5xi8> to tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+
+!qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
+func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
+ %0 = "quant.qcast"(%arg0) : (tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias>
+ return %0 : tensor<4x2x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @qcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+
+// CHECK: %[[CHANNEL_AXIS_SIZE:.*]] = arith.constant 3 : index
+// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[CHANNEL_AXIS_SIZE]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[FLAT_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_0]] : tensor<?x3x?xf32>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_2]] : tensor<?x3x?xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[FLAT_INPUT]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xf32>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<?x3x?xi8>
+
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+ %0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias>
+ return %0 : tensor<*x!qalias>
+}
+
>From b50b2e12ebc9d981ac0d39b3c69d5c0f5f3c0afe Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 09:34:57 -0400
Subject: [PATCH 10/18] Unit test for 0D tensor
---
mlir/test/Dialect/Quant/lower-quant-ops.mlir | 29 ++++++++++++++++++--
1 file changed, 27 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 1aa14c615c6e5..ad3be870916dc 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -86,9 +86,9 @@ func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32>
-// CHECK: %[[STORED_FLOAT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
-// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_FLOAT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
!qalias = !quant.uniform<i8:f32, 2.0:10>
@@ -99,6 +99,31 @@ func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qal
// -----
+// CHECK-LABEL: @qcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
+ %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
+ return %0 : tensor<!qalias>
+}
+
+// -----
+
// CHECK-LABEL: @qcast_per_layer_ranked_bounds
// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
>From 9fc4be12215ad3bb9940203bda6d0114b05fa9c9 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 12:36:41 -0400
Subject: [PATCH 11/18] Support for 'quant.dcast' and unit tests
---
.../Quant/Transforms/LowerQuantOps.cpp | 217 ++++++++++++++++--
mlir/test/Dialect/Quant/lower-quant-ops.mlir | 95 ++++++--
2 files changed, 264 insertions(+), 48 deletions(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1ffbd7032ae58..1b2d36dc672cf 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -30,26 +30,6 @@ namespace quant {
namespace {
-//===----------------------------------------------------------------------===//
-// DequantizeCastOp
-//===----------------------------------------------------------------------===//
-
-class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
-public:
- using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- return success();
- }
-};
-
-
-//===----------------------------------------------------------------------===//
-// QuantizeCastOp
-//===----------------------------------------------------------------------===//
-
// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) {
@@ -243,8 +223,9 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
return builder.create<arith::UIToFPOp>(loc, resultType, input);
}
-// Quantize a floating-point input using the given scale, input shape, and
-// storage type bounds in the given quantized type.
+// Quantize a floating-point input using the given input shape, scale, and
+// zero point. The stored value is clamped using the storage bounds encoded in
+// the given quantized type.
Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape, Value scale,
Value zeroPoint, QuantizedType quantizedType) {
@@ -286,6 +267,196 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
return storedValueClamped;
}
+Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+ ArrayRef<OpFoldResult> inputShape, Value scale,
+ Value zeroPoint, QuantizedType quantizedType) {
+ // Convert scale to tensor if necessary
+ auto inputType = input.getType();
+ scale = getScalarOrTensorConstant(
+ builder, loc, scale, inputType, inputShape);
+
+ // Convert stored value to float
+ auto result = convertIntegerToFloat(
+ builder, loc, input, scale.getType(), quantizedType.isSigned());
+
+ // Skip unnecessary computations if no zero point is given
+ if (!matchPattern(zeroPoint, m_Zero())) {
+ // Convert zero point to tensor if necessary
+ zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+ inputShape);
+
+ // Convert zero point from storage to expressed type
+ zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
+ scale.getType(),
+ quantizedType.isSigned());
+
+ // Subtract zero point to stored value
+ result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
+ }
+
+ // Multiply by scale
+ result = builder.create<arith::MulFOp>(loc, result, scale);
+ return result;
+}
+
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+
+ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedType quantizedType) const {
+
+ // Create scale and zero point constants
+ auto expressedType = quantizedType.getExpressedType();
+ auto storageType = quantizedType.getStorageType();
+ auto scaleAttr =
+ builder.getFloatAttr(expressedType, quantizedType.getScale());
+ auto scale =
+ builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
+ auto zeroPointAttr =
+ builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
+ auto zeroPoint =
+ builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
+
+ auto inputShape = getScalarOrTensorShape(builder, loc, input);
+ return dequantizeScalarOrTensor(builder, loc, input, inputShape, scale,
+ zeroPoint, quantizedType);
+ }
+
+ Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedType quantizedType) const {
+ // Flatten input if unranked
+ bool isUnranked = isa<UnrankedTensorType>(input.getType());
+ Value inputShape;
+ if (isUnranked)
+ std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
+
+ // Process ranked tensor
+ auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
+
+ // Restore original shape if unranked
+ if (isUnranked)
+ result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+ return result;
+ }
+
+ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedPerAxisType quantizedType,
+ int64_t channelAxis) const {
+ auto *context = builder.getContext();
+
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto inputRank = inputType.getRank();
+
+ auto scales = materializePerChannelScales(builder, loc, quantizedType);
+ auto zeroPoints =
+ materializePerChannelZeroPoints(builder, loc, quantizedType);
+
+ auto storageType = quantizedType.getStorageType();
+ auto initShape = tensor::getMixedSizes(builder, loc, input);
+ Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
+
+ SmallVector<utils::IteratorType> iteratorTypes(
+ inputRank, utils::IteratorType::parallel);
+ auto channelAxisAffineMap = AffineMap::get(
+ inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+ SmallVector<AffineMap> indexingMaps{
+ builder.getMultiDimIdentityMap(inputRank),
+ channelAxisAffineMap,
+ channelAxisAffineMap,
+ builder.getMultiDimIdentityMap(inputRank)
+ };
+ auto storedValue = builder.create<linalg::GenericOp>(
+ loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps,
+ iteratorTypes,
+ [&](OpBuilder& builder, Location loc, ValueRange args) {
+ assert(args.size() == 4);
+ auto expressedValue = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
+ scale, zeroPoint, quantizedType);
+
+ builder.create<linalg::YieldOp>(loc, result);
+ })
+ .getResult(0);
+
+ return storedValue;
+ }
+
+ Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
+ UniformQuantizedPerAxisType quantizedType) const {
+ // Flatten unranked tensor into a 3D ranked tensor if necessary
+ bool isUnranked = isa<UnrankedTensorType>(input.getType());
+ int64_t channelAxis = quantizedType.getQuantizedDimension();
+ int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+ Value inputShape;
+ if (isUnranked) {
+ std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+ builder, loc, input, channelAxis, channelAxisSize);
+ channelAxis = 1;
+ }
+
+ // Work on a ranked tensor
+ auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
+ channelAxis);
+
+ // Restore original tensor shape if unranked
+ if (isUnranked)
+ result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+ return result;
+ }
+
+public:
+ using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto input = op.getInput();
+ auto quantizedType =
+ cast<QuantizedType>(getScalarType(op.getInput().getType()));
+
+ // Convert quantized input to storage type
+ auto storageScalarOrTensorType =
+ getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
+ input = rewriter.create<quant::StorageCastOp>(
+ loc, storageScalarOrTensorType, input);
+
+ // Flatten unranked tensor input
+ Value result;
+ if (auto uniformQuantizedType =
+ dyn_cast<UniformQuantizedType>(quantizedType)) {
+ result = convertPerLayer(rewriter, loc, input, uniformQuantizedType);
+ } else if (auto uniformQuantizedPerAxisType =
+ dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
+ result =
+ convertPerChannel(rewriter, loc, input, uniformQuantizedPerAxisType);
+ } else {
+ llvm_unreachable("unexpected quantized type");
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
@@ -417,7 +588,7 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
result = convertPerChannel(rewriter, loc, input, quantizedType);
} else {
- llvm_unreachable("unexpected uniform quantized type");
+ llvm_unreachable("unexpected quantized type");
}
// Cast stored value to result quantized value
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index ad3be870916dc..1030d5b20620b 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -72,6 +72,31 @@ func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
// -----
+// CHECK-LABEL: @qcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
+ %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
+ return %0 : tensor<!qalias>
+}
+
+// -----
+
// CHECK-LABEL: @qcast_per_layer_ranked
// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32>
@@ -99,31 +124,6 @@ func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qal
// -----
-// CHECK-LABEL: @qcast_per_layer_0d
-// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
-
-// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
-
-// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
-// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
-
-// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
-// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
-// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
-// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
-
-// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
-// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
-
-!qalias = !quant.uniform<i8:f32, 2.0:10>
-func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
- %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
- return %0 : tensor<!qalias>
-}
-
-// -----
-
// CHECK-LABEL: @qcast_per_layer_ranked_bounds
// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
@@ -305,3 +305,48 @@ func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias>
return %0 : tensor<*x!qalias>
}
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<u8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return %0 : f32
+}
+
>From 5b40f256c6dd31f511da5ff58608e9c28732a0b6 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 14:33:38 -0400
Subject: [PATCH 12/18] Refactored quant.qcast and quant.dcast common code
---
.../Quant/Transforms/LowerQuantOps.cpp | 418 +++++++-----------
1 file changed, 164 insertions(+), 254 deletions(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 1b2d36dc672cf..77be3dd11d3ad 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -223,12 +223,11 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
return builder.create<arith::UIToFPOp>(loc, resultType, input);
}
-// Quantize a floating-point input using the given input shape, scale, and
-// zero point. The stored value is clamped using the storage bounds encoded in
-// the given quantized type.
-Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
- ArrayRef<OpFoldResult> inputShape, Value scale,
- Value zeroPoint, QuantizedType quantizedType) {
+// Quantize a scalar or ranked tensor value. The stored value is clamped using
+// the storage bounds encoded in the given quantized type.
+Value quantizeValue(OpBuilder &builder, Location loc, Value input,
+ ArrayRef<OpFoldResult> inputShape, Value scale,
+ Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
scale = getScalarOrTensorConstant(
@@ -267,9 +266,10 @@ Value quantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
return storedValueClamped;
}
-Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
- ArrayRef<OpFoldResult> inputShape, Value scale,
- Value zeroPoint, QuantizedType quantizedType) {
+// Dequantize a scalar or ranked tensor value.
+Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
+ ArrayRef<OpFoldResult> inputShape, Value scale,
+ Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
scale = getScalarOrTensorConstant(
@@ -299,125 +299,168 @@ Value dequantizeScalarOrTensor(OpBuilder &builder, Location loc, Value input,
return result;
}
+// Convert a scalar or ranked tensor input with the given scale and zero point
+// values.
+//
+// - input
+// Scalar or ranked tensor value.
+//
+// - inputShape
+// If 'input' is a tensor, combination or attributes/values representing its
+// static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - scale
+// Scale as a scalar value.
+//
+// - zeroPoint
+// Zero point as a scalar value.
+//
+// - quantizedType
+// Scalar quantized type of the result ('quant.qcast') or of the input
+// ('quant.dcast').
+//
+Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
+ Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
+ Value zeroPoint, QuantizedType quantizedType) {
+ if (isa<QuantizeCastOp>(op))
+ return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+ quantizedType);
+ if (isa<DequantizeCastOp>(op))
+ return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+ quantizedType);
+ llvm_unreachable("unexpected quant op");
+}
-//===----------------------------------------------------------------------===//
-// DequantizeCastOp
-//===----------------------------------------------------------------------===//
+Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
+ Value input, UniformQuantizedType quantizedType) {
+
+ // Create scale and zero point constants
+ auto expressedType = quantizedType.getExpressedType();
+ auto storageType = quantizedType.getStorageType();
+ auto scaleAttr =
+ builder.getFloatAttr(expressedType, quantizedType.getScale());
+ auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
+ auto zeroPointAttr =
+ builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
+ auto zeroPoint =
+ builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
+
+ auto inputShape = getScalarOrTensorShape(builder, loc, input);
+ return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
+ quantizedType);
+}
-class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
-
- Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedType quantizedType) const {
-
- // Create scale and zero point constants
- auto expressedType = quantizedType.getExpressedType();
- auto storageType = quantizedType.getStorageType();
- auto scaleAttr =
- builder.getFloatAttr(expressedType, quantizedType.getScale());
- auto scale =
- builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
- auto zeroPointAttr =
- builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
- auto zeroPoint =
- builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
-
- auto inputShape = getScalarOrTensorShape(builder, loc, input);
- return dequantizeScalarOrTensor(builder, loc, input, inputShape, scale,
+Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
+ Value input, UniformQuantizedType quantizedType) {
+ // Flatten input if unranked
+ bool isUnranked = isa<UnrankedTensorType>(input.getType());
+ Value inputShape;
+ if (isUnranked)
+ std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
+
+ // Process ranked tensor
+ auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
+
+ // Restore original shape if unranked
+ if (isUnranked)
+ result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+ return result;
+}
+
+Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
+ Value input,
+ UniformQuantizedPerAxisType quantizedType,
+ int64_t channelAxis) {
+ auto *context = builder.getContext();
+
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto inputRank = inputType.getRank();
+
+ auto scales = materializePerChannelScales(builder, loc, quantizedType);
+ auto zeroPoints =
+ materializePerChannelZeroPoints(builder, loc, quantizedType);
+
+ auto storageType = quantizedType.getStorageType();
+ auto initShape = tensor::getMixedSizes(builder, loc, input);
+ Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
+
+ SmallVector<utils::IteratorType> iteratorTypes(
+ inputRank, utils::IteratorType::parallel);
+ auto channelAxisAffineMap = AffineMap::get(
+ inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+ SmallVector<AffineMap> indexingMaps{
+ builder.getMultiDimIdentityMap(inputRank),
+ channelAxisAffineMap,
+ channelAxisAffineMap,
+ builder.getMultiDimIdentityMap(inputRank)
+ };
+ auto storedValue = builder.create<linalg::GenericOp>(
+ loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps,
+ iteratorTypes,
+ [&](OpBuilder& builder, Location loc, ValueRange args) {
+ assert(args.size() == 4);
+ auto expressedValue = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ auto result = convertRanked(builder, loc, op, expressedValue, {}, scale,
zeroPoint, quantizedType);
+
+ builder.create<linalg::YieldOp>(loc, result);
+ })
+ .getResult(0);
+
+ return storedValue;
+}
+
+Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
+ Value input,
+ UniformQuantizedPerAxisType quantizedType) {
+ // Flatten unranked tensor into a 3D ranked tensor if necessary
+ bool isUnranked = isa<UnrankedTensorType>(input.getType());
+ int64_t channelAxis = quantizedType.getQuantizedDimension();
+ int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+ Value inputShape;
+ if (isUnranked) {
+ std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+ builder, loc, input, channelAxis, channelAxisSize);
+ channelAxis = 1;
}
- Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedType quantizedType) const {
- // Flatten input if unranked
- bool isUnranked = isa<UnrankedTensorType>(input.getType());
- Value inputShape;
- if (isUnranked)
- std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
+ // Work on a ranked tensor
+ auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
+ channelAxis);
- // Process ranked tensor
- auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
+ // Restore original tensor shape if unranked
+ if (isUnranked)
+ result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
- // Restore original shape if unranked
- if (isUnranked)
- result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+ return result;
+}
- return result;
- }
+Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
+ Value input, Type quantizedType) {
+ if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
+ return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
- Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedPerAxisType quantizedType,
- int64_t channelAxis) const {
- auto *context = builder.getContext();
-
- auto inputType = cast<RankedTensorType>(input.getType());
- auto inputRank = inputType.getRank();
-
- auto scales = materializePerChannelScales(builder, loc, quantizedType);
- auto zeroPoints =
- materializePerChannelZeroPoints(builder, loc, quantizedType);
-
- auto storageType = quantizedType.getStorageType();
- auto initShape = tensor::getMixedSizes(builder, loc, input);
- Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
-
- SmallVector<utils::IteratorType> iteratorTypes(
- inputRank, utils::IteratorType::parallel);
- auto channelAxisAffineMap = AffineMap::get(
- inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
- SmallVector<AffineMap> indexingMaps{
- builder.getMultiDimIdentityMap(inputRank),
- channelAxisAffineMap,
- channelAxisAffineMap,
- builder.getMultiDimIdentityMap(inputRank)
- };
- auto storedValue = builder.create<linalg::GenericOp>(
- loc,
- init.getType(), // resultType
- ValueRange{input, scales, zeroPoints}, // inputs
- ValueRange{init}, // outputs
- indexingMaps,
- iteratorTypes,
- [&](OpBuilder& builder, Location loc, ValueRange args) {
- assert(args.size() == 4);
- auto expressedValue = args[0];
- auto scale = args[1];
- auto zeroPoint = args[2];
-
- auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
- scale, zeroPoint, quantizedType);
-
- builder.create<linalg::YieldOp>(loc, result);
- })
- .getResult(0);
-
- return storedValue;
- }
+ if (auto uniformQuantizedPerAxisType =
+ dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
+ return convertPerChannel(builder, loc, op, input,
+ uniformQuantizedPerAxisType);
- Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedPerAxisType quantizedType) const {
- // Flatten unranked tensor into a 3D ranked tensor if necessary
- bool isUnranked = isa<UnrankedTensorType>(input.getType());
- int64_t channelAxis = quantizedType.getQuantizedDimension();
- int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
- Value inputShape;
- if (isUnranked) {
- std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
- builder, loc, input, channelAxis, channelAxisSize);
- channelAxis = 1;
- }
-
- // Work on a ranked tensor
- auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
- channelAxis);
-
- // Restore original tensor shape if unranked
- if (isUnranked)
- result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
-
- return result;
- }
+ llvm_unreachable("unexpected quantized type");
+}
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
-public:
+struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
LogicalResult
@@ -434,19 +477,7 @@ class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeC
input = rewriter.create<quant::StorageCastOp>(
loc, storageScalarOrTensorType, input);
- // Flatten unranked tensor input
- Value result;
- if (auto uniformQuantizedType =
- dyn_cast<UniformQuantizedType>(quantizedType)) {
- result = convertPerLayer(rewriter, loc, input, uniformQuantizedType);
- } else if (auto uniformQuantizedPerAxisType =
- dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
- result =
- convertPerChannel(rewriter, loc, input, uniformQuantizedPerAxisType);
- } else {
- llvm_unreachable("unexpected quantized type");
- }
-
+ auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
rewriter.replaceOp(op, result);
return success();
}
@@ -457,120 +488,7 @@ class DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeC
// QuantizeCastOp
//===----------------------------------------------------------------------===//
-class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
-
- Value convertPerLayerRanked(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedType quantizedType) const {
-
- // Create scale and zero point constants
- auto expressedType = quantizedType.getExpressedType();
- auto storageType = quantizedType.getStorageType();
- auto scaleAttr =
- builder.getFloatAttr(expressedType, quantizedType.getScale());
- auto scale =
- builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
- auto zeroPointAttr =
- builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
- auto zeroPoint =
- builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
-
- auto inputShape = getScalarOrTensorShape(builder, loc, input);
- return quantizeScalarOrTensor(builder, loc, input, inputShape, scale,
- zeroPoint, quantizedType);
- }
-
- Value convertPerLayer(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedType quantizedType) const {
- // Flatten input if unranked
- bool isUnranked = isa<UnrankedTensorType>(input.getType());
- Value inputShape;
- if (isUnranked)
- std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
-
- // Process ranked tensor
- auto result = convertPerLayerRanked(builder, loc, input, quantizedType);
-
- // Restore original shape if unranked
- if (isUnranked)
- result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
-
- return result;
- }
-
- Value convertPerChannelRanked(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedPerAxisType quantizedType,
- int64_t channelAxis) const {
- auto *context = builder.getContext();
-
- auto inputType = cast<RankedTensorType>(input.getType());
- auto inputRank = inputType.getRank();
-
- auto scales = materializePerChannelScales(builder, loc, quantizedType);
- auto zeroPoints =
- materializePerChannelZeroPoints(builder, loc, quantizedType);
-
- auto storageType = quantizedType.getStorageType();
- auto initShape = tensor::getMixedSizes(builder, loc, input);
- Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
-
- SmallVector<utils::IteratorType> iteratorTypes(
- inputRank, utils::IteratorType::parallel);
- auto channelAxisAffineMap = AffineMap::get(
- inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
- SmallVector<AffineMap> indexingMaps{
- builder.getMultiDimIdentityMap(inputRank),
- channelAxisAffineMap,
- channelAxisAffineMap,
- builder.getMultiDimIdentityMap(inputRank)
- };
- auto storedValue = builder.create<linalg::GenericOp>(
- loc,
- init.getType(), // resultType
- ValueRange{input, scales, zeroPoints}, // inputs
- ValueRange{init}, // outputs
- indexingMaps,
- iteratorTypes,
- [&](OpBuilder& builder, Location loc, ValueRange args) {
- assert(args.size() == 4);
- auto expressedValue = args[0];
- auto scale = args[1];
- auto zeroPoint = args[2];
-
- auto result = quantizeScalarOrTensor(builder, loc, expressedValue, {},
- scale, zeroPoint, quantizedType);
-
- builder.create<linalg::YieldOp>(loc, result);
- })
- .getResult(0);
-
- return storedValue;
- }
-
- Value convertPerChannel(OpBuilder &builder, Location loc, Value input,
- UniformQuantizedPerAxisType quantizedType) const {
- // Flatten unranked tensor into a 3D ranked tensor if necessary
- bool isUnranked = isa<UnrankedTensorType>(input.getType());
- int64_t channelAxis = quantizedType.getQuantizedDimension();
- int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
- Value inputShape;
- if (isUnranked) {
- std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
- builder, loc, input, channelAxis, channelAxisSize);
- channelAxis = 1;
- }
-
- // Work on a ranked tensor
- auto result = convertPerChannelRanked(builder, loc, input, quantizedType,
- channelAxis);
-
- // Restore original tensor shape if unranked
- if (isUnranked)
- result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
-
- return result;
- }
-
-public:
+struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
LogicalResult
@@ -578,18 +496,10 @@ class QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastO
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto input = op.getInput();
- auto resultScalarType = getScalarType(op.getResult().getType());
+ auto quantizedType = getScalarType(op.getResult().getType());
// Flatten unranked tensor input
- Value result;
- if (auto quantizedType = dyn_cast<UniformQuantizedType>(resultScalarType)) {
- result = convertPerLayer(rewriter, loc, input, quantizedType);
- } else if (auto quantizedType =
- dyn_cast<UniformQuantizedPerAxisType>(resultScalarType)) {
- result = convertPerChannel(rewriter, loc, input, quantizedType);
- } else {
- llvm_unreachable("unexpected quantized type");
- }
+ auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
// Cast stored value to result quantized value
rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
>From 6eabe11b9880d2e22b2d4077d4f2639845219cce Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 22 Jul 2024 17:29:17 -0400
Subject: [PATCH 13/18] Unit test for 'quant.dcast' lowering with bug fixes
---
.../Quant/Transforms/LowerQuantOps.cpp | 172 ++++++++++--
mlir/test/Dialect/Quant/lower-quant-ops.mlir | 255 ++++++++++++++----
2 files changed, 360 insertions(+), 67 deletions(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 77be3dd11d3ad..4adeb9218ff8e 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -76,6 +76,19 @@ Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
return tensorConstant;
}
+// Reshape an unranked tensor into a 1D ranked tensor.
+//
+// - input
+// Unranked tensor.
+//
+// Return values:
+//
+// - flatInput
+// 1D ranked, dynamically shaped tensor.
+//
+// - inputShape
+// 1D extent tensor containing the shape of the original unranked input.
+//
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
Value input) {
// Get unranked input shape and total size
@@ -100,6 +113,28 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
return std::make_pair(flatInput, inputShape);
}
+// Reshape an unranked tensor into a 3D ranked tensor where the central
+// dimension of the result tensor corresponds to dimension 'axis' of the input
+// tensor.
+//
+// - input
+// Unranked tensor.
+//
+// - axis
+// Index of the input dimension around which other input dimensions will be
+// collapsed.
+//
+// - axisSize
+// Size of input dimension 'axis'.
+//
+// Return values:
+//
+// - flatInput
+// 3D ranked tensor of shape [?, axisSize, ?].
+//
+// - inputShape
+// 1D extent tensor containing the shape of the original unranked input.
+//
std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
Location loc,
Value input,
@@ -142,6 +177,14 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
return std::make_pair(flatInput, inputShape);
}
+// Reshape an input tensor into its original unranked shape.
+//
+// - input
+// Ranked tensor.
+//
+// - inputShape
+// 1D extent tensor.
+//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
Value inputShape) {
auto inputType = cast<RankedTensorType>(input.getType());
@@ -150,6 +193,15 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
}
+// Create a tensor constant containing all scales in a per-channel quantized
+// type. Example:
+//
+// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+// %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
+//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
UniformQuantizedPerAxisType quantizedType) {
auto scales = quantizedType.getScales();
@@ -162,6 +214,15 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}
+// Create a tensor constant containing all zero points in a per-channel
+// quantized type. Example:
+//
+// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+// %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
+//
Value materializePerChannelZeroPoints(
OpBuilder &builder, Location loc,
UniformQuantizedPerAxisType quantizedType) {
@@ -178,6 +239,19 @@ Value materializePerChannelZeroPoints(
return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}
+// Clamp the given scalar or tensor input using the storage bounds encoded in
+// the given quantized type, if present.
+//
+// - input
+// Scalar or ranked tensor input. The element type must match the storage type
+// of 'quantizedType'.
+//
+// - inputShape
+// If 'input' is a tensor, combination of attributes/values representing its
+// static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - quantizedType
+// Per-layer or per-channel quantized type.
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape,
QuantizedType quantizedType) {
@@ -209,6 +283,7 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
return input;
}
+// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
Type resultType, bool isSigned) {
if (isSigned)
@@ -216,6 +291,7 @@ Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
return builder.create<arith::FPToUIOp>(loc, resultType, input);
}
+// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
Type resultType, bool isSigned) {
if (isSigned)
@@ -225,6 +301,8 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
+//
+// See function 'convertRanked()' below for a description of the arguments.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape, Value scale,
Value zeroPoint, QuantizedType quantizedType) {
@@ -266,7 +344,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
return storedValueClamped;
}
-// Dequantize a scalar or ranked tensor value.
+// Dequantize a scalar or ranked tensor input.
+//
+// See function 'convertRanked()' below for a description of the arguments.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape, Value scale,
Value zeroPoint, QuantizedType quantizedType) {
@@ -310,10 +390,10 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
// static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
-// Scale as a scalar value.
+// Scale as a floating-point scalar value.
//
// - zeroPoint
-// Zero point as a scalar value.
+// Zero point as an integer scalar value.
//
// - quantizedType
// Scalar quantized type of the result ('quant.qcast') or of the input
@@ -331,9 +411,20 @@ Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
llvm_unreachable("unexpected quant op");
}
+// Convert an operation using per-layer quantization with a scalar or ranked
+// tensor input.
+//
+// - op
+// 'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+// Scalar or ranked tensor.
+//
+// - quantizedType
+// Per-layer quantized type.
+//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
Value input, UniformQuantizedType quantizedType) {
-
// Create scale and zero point constants
auto expressedType = quantizedType.getExpressedType();
auto storageType = quantizedType.getStorageType();
@@ -350,6 +441,17 @@ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
quantizedType);
}
+// Convert an operation using per-layer quantization.
+//
+// - op
+// 'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+// Scalar, ranked tensor, or unranked tensor.
+//
+// - quantizedType
+// Per-layer quantized type.
+//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
Value input, UniformQuantizedType quantizedType) {
// Flatten input if unranked
@@ -368,6 +470,18 @@ Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
return result;
}
+// Convert an operation using per-channel quantization with a scalar or ranked
+// tensor input.
+//
+// - op
+// 'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+// Scalar or ranked tensor.
+//
+// - quantizedType
+// Per-channel quantized type.
+//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
Value input,
UniformQuantizedPerAxisType quantizedType,
@@ -381,9 +495,11 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
auto zeroPoints =
materializePerChannelZeroPoints(builder, loc, quantizedType);
- auto storageType = quantizedType.getStorageType();
+ auto elementType = isa<FloatType>(inputType.getElementType())
+ ? quantizedType.getStorageType()
+ : quantizedType.getExpressedType();
auto initShape = tensor::getMixedSizes(builder, loc, input);
- Value init = builder.create<tensor::EmptyOp>(loc, initShape, storageType);
+ Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
@@ -395,7 +511,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
channelAxisAffineMap,
builder.getMultiDimIdentityMap(inputRank)
};
- auto storedValue = builder.create<linalg::GenericOp>(
+ auto result = builder.create<linalg::GenericOp>(
loc,
init.getType(), // resultType
ValueRange{input, scales, zeroPoints}, // inputs
@@ -404,20 +520,31 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
iteratorTypes,
[&](OpBuilder& builder, Location loc, ValueRange args) {
assert(args.size() == 4);
- auto expressedValue = args[0];
+ auto input = args[0];
auto scale = args[1];
auto zeroPoint = args[2];
- auto result = convertRanked(builder, loc, op, expressedValue, {}, scale,
+ auto result = convertRanked(builder, loc, op, input, {}, scale,
zeroPoint, quantizedType);
builder.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
- return storedValue;
+ return result;
}
+// Convert an operation using per-channel quantization.
+//
+// - op
+// 'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+// Scalar, ranked tensor, or unranked tensor.
+//
+// - quantizedType
+// Per-channel quantized type.
+//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
Value input,
UniformQuantizedPerAxisType quantizedType) {
@@ -443,6 +570,19 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
return result;
}
+// Convert a quantization operation.
+//
+// - op
+// 'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+// Scalar, ranked tensor, or unranked tensor. The element type matches
+// the storage type (quant.dcast) or expressed type (quant.qcast) of
+// 'quantizedType'.
+//
+// - quantizedType
+// Per-layer or per-channel quantized type.
+//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
Value input, Type quantizedType) {
if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
@@ -456,10 +596,7 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
llvm_unreachable("unexpected quantized type");
}
-//===----------------------------------------------------------------------===//
-// DequantizeCastOp
-//===----------------------------------------------------------------------===//
-
+// Lowering pattern for 'quant.dcast'
struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
@@ -478,16 +615,13 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
loc, storageScalarOrTensorType, input);
auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
+
rewriter.replaceOp(op, result);
return success();
}
};
-
-//===----------------------------------------------------------------------===//
-// QuantizeCastOp
-//===----------------------------------------------------------------------===//
-
+// Lowering pattern for 'quant.qcast'
struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 1030d5b20620b..6bba9f5c03772 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -1,5 +1,209 @@
// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s
+// CHECK-LABEL: @dcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+
+!qalias = !quant.uniform<u8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<i8>
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<i8> to tensor<f32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<f32>
+// CHECK: return %[[EXPRESSED]] : tensor<f32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_0d(%arg0: tensor<!qalias>) -> tensor<f32> {
+ %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<3x?x5xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_INT]], %[[C_1]] : tensor<3x?x5xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+// CHECK: return %[[EXPRESSED]] : tensor<3x?x5xf32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_ranked(%arg0: tensor<3x?x5x!qalias>) -> tensor<3x?x5xf32> {
+ %0 = quant.dcast %arg0 : tensor<3x?x5x!qalias> to tensor<3x?x5xf32>
+ return %0 : tensor<3x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<*xi8>
+// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[STORED_INT]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_INT]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<1xindex>) -> tensor<?xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_COLLAPSED]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<?xi8> to tensor<?xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[EXPRESSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+ %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>> to tensor<4x?x?x5xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_1]] : tensor<4x?x?x5xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_2]] : tensor<4x?x?x5xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[STORED_TENSOR]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xi8>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<4x?x?x5xf32>
+// CHECK: return %[[GENERIC]] : tensor<4x?x?x5xf32>
+
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @dcast_per_channel_ranked(%arg0: tensor<4x?x?x5x!qalias>) -> tensor<4x?x?x5xf32> {
+ %0 = quant.dcast %arg0 : tensor<4x?x?x5x!qalias> to tensor<4x?x?x5xf32>
+ return %0 : tensor<4x?x?x5xf32>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>> to tensor<*xi8>
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[STORED_TENSOR]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+
+// CHECK: %[[NUM_CHANNELS:.*]] = arith.constant 3 : index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[NUM_CHANNELS]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_TENSOR]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<3xindex>) -> tensor<?x3x?xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?x3x?xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_2]] : tensor<?x3x?xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[STORED_COLLAPSED]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xi8>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<?x3x?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @dcast_per_channel_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+ %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: @qcast_per_layer_scalar
// CHECK-SAME: %[[ARG_0:.*]]: f32
@@ -219,7 +423,7 @@ func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> {
- %0 = "quant.qcast"(%arg0) : (tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias>
+ %0 = quant.qcast %arg0 : tensor<4x?x?x5xf32> to tensor<4x?x?x5x!qalias>
return %0 : tensor<4x?x?x5x!qalias>
}
@@ -253,7 +457,7 @@ func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x
!qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
- %0 = "quant.qcast"(%arg0) : (tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias>
+ %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias>
return %0 : tensor<4x2x5x!qalias>
}
@@ -301,52 +505,7 @@ func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4
!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
- %0 = "quant.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!qalias>
+ %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
return %0 : tensor<*x!qalias>
}
-// -----
-
-// CHECK-LABEL: @dcast_per_layer_scalar
-// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
-
-// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
-
-// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
-// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
-// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
-
-// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
-// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
-// CHECK: return %[[EXPRESSED]] : f32
-
-!qalias = !quant.uniform<i8:f32, 2.0:10>
-func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
- %0 = quant.dcast %arg0 : !qalias to f32
- return %0 : f32
-}
-
-// -----
-
-// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
-// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
-
-// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
-
-// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
-
-// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
-// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
-
-// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
-// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
-// CHECK: return %[[EXPRESSED]] : f32
-
-!qalias = !quant.uniform<u8:f32, 2.0:10>
-func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
- %0 = quant.dcast %arg0 : !qalias to f32
- return %0 : f32
-}
-
>From 647b8ec64d01af7606c29a0f90ad98b49ba8fa95 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 23 Jul 2024 15:48:56 -0400
Subject: [PATCH 14/18] Verifiers for 'quant.dcast' and 'quant.qcast'
---
mlir/include/mlir/Dialect/Quant/IR/Quant.h | 2 +
.../mlir/Dialect/Quant/IR/QuantBase.td | 96 ++++++-----
.../include/mlir/Dialect/Quant/IR/QuantOps.td | 40 ++++-
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 105 +++++++++++-
mlir/test/Dialect/Quant/invalid.mlir | 161 ++++++++++++++++++
mlir/test/Dialect/Quant/ops.mlir | 96 +++++++++++
6 files changed, 443 insertions(+), 57 deletions(-)
create mode 100644 mlir/test/Dialect/Quant/invalid.mlir
create mode 100644 mlir/test/Dialect/Quant/ops.mlir
diff --git a/mlir/include/mlir/Dialect/Quant/IR/Quant.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
index c5ca88ec69795..11a969a3ee519 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/Quant.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -24,7 +24,9 @@
namespace mlir {
namespace quant {
+class QuantizedType;
class UniformQuantizedType;
+class UniformQuantizedPerAxisType;
} // namespace quant
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index e465d855c1986..8ef89a0d1393b 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -22,53 +22,63 @@ def Quant_Dialect : Dialect {
let useDefaultTypePrinterParser = 1;
}
+
//===----------------------------------------------------------------------===//
-// Quantization type definitions
+// Type definitions
//===----------------------------------------------------------------------===//
-class quant_TypedPrimitiveOrContainer<Type etype> :
- Type<Or<[etype.predicate,
- TensorOf<[etype]>.predicate,
- VectorOf<[etype]>.predicate]>,
- "primitive/tensor/vector of " # etype.summary>;
+class quant_ScalarOrTensorOf<Type etype> :
+ Type<Or<[etype.predicate, TensorOf<[etype]>.predicate]>,
+ "scalar or tensor of " # etype.summary>;
-// An implementation of QuantizedType.
def quant_QuantizedType :
- Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "QuantizedType">;
-
-// A primitive type that can represent a real value. This is either a
-// floating point value or a quantized type.
-def quant_RealPrimitiveType :
- Type<Or<[AnyFloat.predicate, quant_QuantizedType.predicate]>,
- "real valued primitive (float or quantized type)">;
-
-// A primitive type that can represent a storage value. This is either an
-// integer or quantized type.
-def quant_StoragePrimitiveType :
- Type<Or<[AnySignlessInteger.predicate, quant_QuantizedType.predicate]>,
- "quantized storage primitive (integer or quantized type)">;
-
-// A primitive or container of RealPrimitiveType.
-def quant_RealValueType :
- quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
-
-// A primitive or container of StoragePrimitiveType.
-def quant_StorageValueType :
- quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
-
-// Either a real valued or storage primitive or container type.
-def quant_RealOrStorageValueType :
- Type<Or<[quant_RealValueType.predicate, quant_StorageValueType.predicate]>,
- "real valued or storage primitive or container type">;
-
-// An implementation of UniformQuantizedType.
-def quant_UniformQuantizedType :
- DialectType<Quant_Dialect,
- CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
- "UniformQuantizedType">;
-
-// Predicate for detecting a container or primitive of UniformQuantizedType.
-def quant_UniformQuantizedValueType :
- quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
+ Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "quantized type">;
+
+def quant_ScalarType :
+ Type<Or<[
+ AnyInteger.predicate,
+ AnyFloat.predicate,
+ quant_QuantizedType.predicate
+ ]>, "integer, float, or quantized scalar">;
+
+def quant_IntegerOrQuantizedType :
+ Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>, "integer or quantized type">;
+
+def quant_FloatScalarOrTensor :
+ quant_ScalarOrTensorOf<AnyFloat>;
+
+def quant_IntegerScalarOrTensor :
+ quant_ScalarOrTensorOf<AnyInteger>;
+
+def quant_QuantizedScalarOrTensor :
+ quant_ScalarOrTensorOf<quant_QuantizedType>;
+
+def quant_IntegerOrQuantizedScalarOrTensor :
+ quant_ScalarOrTensorOf<quant_IntegerOrQuantizedType>;
+
+
+//===----------------------------------------------------------------------===//
+// Traits
+//===----------------------------------------------------------------------===//
+
+def quant_SameScalarOrTensorShape :
+ PredOpTrait<
+ "input and result are both scalars or both tensors with matching shape",
+ Or<[
+ And<[
+ TypeIsPred<"input", quant_ScalarType>,
+ TypeIsPred<"result", quant_ScalarType>
+ ]>,
+ And<[
+ TypeIsPred<"input", AnyUnrankedTensor>,
+ TypeIsPred<"result", AnyUnrankedTensor>
+ ]>,
+ And<[
+ TypeIsPred<"input", AnyRankedTensor>,
+ TypeIsPred<"result", AnyRankedTensor>,
+ AllShapesMatch<["input", "result"]>.predicate
+ ]>
+ ]>
+ >;
#endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 7a6d270dbb6e9..3a24c8148d1f5 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -28,7 +28,9 @@ class quant_Op<string mnemonic, list<Trait> traits> :
// Quantization casts
//===----------------------------------------------------------------------===//
-def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
+def quant_DequantizeCastOp : quant_Op<"dcast", [
+ Pure,
+ quant_SameScalarOrTensorShape]> {
let summary = "convert back from a quantized to quantizable (expressed) type operation";
let description = [{
A DequantizeCast op `dcast` represents the inverse of a `qcast`,
@@ -42,12 +44,22 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
all operands to ops that must operate with the expressed type (typically
math ops prior to lowering to target-specific, quantized kernels).
}];
- let arguments = (ins quant_RealValueType:$input);
- let results = (outs quant_RealValueType:$result);
+ let arguments = (ins quant_QuantizedScalarOrTensor:$input);
+ let results = (outs quant_FloatScalarOrTensor:$result);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+ let hasVerifier = 1;
+ let extraClassDeclaration = [{
+ /// Return the float type of the scalar or tensor result.
+ FloatType getFloatType();
+
+ /// Return the quantized type of the scalar or tensor input.
+ quant::QuantizedType getQuantizedType();
+ }];
}
-def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
+def quant_QuantizeCastOp : quant_Op<"qcast", [
+ Pure,
+ quant_SameScalarOrTensorShape]> {
let summary = "convert a quantizable type to a quantized type";
let description = [{
A QuantizeCast `qcast` represents a potential type shift from a quantizable
@@ -71,12 +83,22 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
it is legal to use a quantized representation (but is not known to be
acceptable).
}];
- let arguments = (ins quant_RealValueType:$input);
- let results = (outs quant_RealValueType:$result);
+ let arguments = (ins quant_FloatScalarOrTensor:$input);
+ let results = (outs quant_QuantizedScalarOrTensor:$result);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+ let hasVerifier = 1;
+ let extraClassDeclaration = [{
+ /// Return the float type of the scalar or tensor input.
+ FloatType getFloatType();
+
+ /// Return the quantized type of the scalar or tensor result.
+ quant::QuantizedType getQuantizedType();
+ }];
}
-def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
+def quant_StorageCastOp : quant_Op<"scast", [
+ Pure,
+ quant_SameScalarOrTensorShape]> {
let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
let description = [{
A StorageCast `scast` represents a cast from or to a type based on the
@@ -97,8 +119,8 @@ def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
```
}];
- let arguments = (ins quant_RealOrStorageValueType:$input);
- let results = (outs quant_RealOrStorageValueType:$result);
+ let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
+ let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index e04ca7eb7e715..ef66ccf3d9e01 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
@@ -28,13 +29,70 @@ namespace quant {
namespace {
-Type getPrimitiveType(Type ty) {
- if (auto tensorType = dyn_cast<TensorType>(ty))
- return tensorType.getElementType();
- return ty;
+// Verify the integrity of per-axis quantization information, if present.
+//
+// - quantizedType
+// Any quantized type. Any quantized type with no per-axis quantization is
+// ignored.
+//
+// - containerType
+// Original input or result type of the operation using the provided quantized
+// type. Used to ensure that the quantized type appears within a tensor and
+// that the tensor is compatible with per-axis quantization information.
+//
+LogicalResult verifyPerAxisQuantization(Operation *op,
+ QuantizedType quantizedType,
+ Type containerType) {
+ auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+ if (!quantizedPerAxisType)
+ return success();
+
+ auto tensorType = dyn_cast<TensorType>(containerType);
+ if (!tensorType)
+ return op->emitError("scalar types may not use per-axis quantization");
+
+ if (!tensorType.hasRank())
+ return success();
+
+ int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+ if (quantizedDimension >= tensorType.getRank())
+ return op->emitError("quantized dimension must be less than tensor rank");
+
+ int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
+ if (quantizedDimensionSize != ShapedType::kDynamic &&
+ quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+ return op->emitError(
+ "quantized dimension size does not match number of scales");
+
+ return success();
+}
+
+// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
+//
+// - quantizedType
+// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
+// whether as a primitive type or in a tensor.
+//
+// - floatType
+// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
+// whether as a primitive type or in a tensor.
+//
+// - containerType
+// Type of original input or result.
+//
+LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
+ FloatType floatType, Type containerType) {
+ if (quantizedType.getExpressedType() != floatType)
+ return op->emitError(
+ "expressed type in quantized type expected to match float type");
+
+ if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
+ return failure();
+
+ return success();
}
-} // namespace
+} // namespace
//===----------------------------------------------------------------------===//
@@ -52,6 +110,24 @@ void QuantDialect::initialize() {
}
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DequantizeCastOp::verify() {
+ return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+ getInput().getType());
+}
+
+FloatType DequantizeCastOp::getFloatType() {
+ return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+QuantizedType DequantizeCastOp::getQuantizedType() {
+ return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+
//===----------------------------------------------------------------------===//
// StorageCastOp
//===----------------------------------------------------------------------===//
@@ -65,6 +141,25 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
return srcScastOp.getInput();
}
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult QuantizeCastOp::verify() {
+ return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+ getInput().getType());
+}
+
+FloatType QuantizeCastOp::getFloatType() {
+ return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+QuantizedType QuantizeCastOp::getQuantizedType() {
+ return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+
} // namespace quant
} // namespace mlir
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
new file mode 100644
index 0000000000000..9976ce5d00d65
--- /dev/null
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+func.func @dcast_invalid_input(%arg0: f32) {
+ // expected-error@+1 {{operand #0 must be scalar or tensor of quantized type}}
+ %0 = quant.dcast %arg0 : f32 to f32
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_invalid_result(%arg0: !qalias) {
+ // expected-error@+1 {{result #0 must be scalar or tensor of floating-point}}
+ %0 = quant.dcast %arg0 : !qalias to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_scalar_tensor(%arg0: !qalias) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.dcast %arg0 : !qalias to tensor<f32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_ranked_unranked_tensor(%arg0: tensor<!qalias>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<*xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3x!qalias>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<?x3xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_float_type_mismatch(%arg0: !qalias) {
+ // expected-error@+1 {{expressed type in quantized type expected to match float type}}
+ %0 = quant.dcast %arg0 : !qalias to f64
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_scalar(%arg0: !qalias) {
+ // expected-error@+1 {{scalar types may not use per-axis quantization}}
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x!qalias>) {
+ // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+ %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_invalid_num_scales(%arg0: tensor<2x3x4x!qalias>) {
+ // expected-error@+1 {{quantized dimension size does not match number of scales}}
+ %0 = quant.dcast %arg0 : tensor<2x3x4x!qalias> to tensor<2x3x4xf32>
+ return
+}
+
+// -----
+
+func.func @qcast_invalid_input(%arg0: f32) {
+ // expected-error@+1 {{result #0 must be scalar or tensor of quantized type}}
+ %0 = quant.qcast %arg0 : f32 to f32
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_invalid_result(%arg0: !qalias) {
+ // expected-error@+1 {{operand #0 must be scalar or tensor of floating-point}}
+ %0 = quant.qcast %arg0 : !qalias to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_scalar_tensor(%arg0: tensor<f32>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.qcast %arg0 : tensor<f32> to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_ranked_unranked_tensor(%arg0: tensor<f32>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.qcast %arg0 : tensor<f32> to tensor<*x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xf32>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<?x3x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_float_type_mismatch(%arg0: f64) {
+ // expected-error@+1 {{expressed type in quantized type expected to match float type}}
+ %0 = quant.qcast %arg0 : f64 to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_scalar(%arg0: f32) {
+ // expected-error@+1 {{scalar types may not use per-axis quantization}}
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3xf32>) {
+ // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+ %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_invalid_num_scales(%arg0: tensor<2x3x4xf32>) {
+ // expected-error@+1 {{quantized dimension size does not match number of scales}}
+ %0 = quant.qcast %arg0 : tensor<2x3x4xf32> to tensor<2x3x4x!qalias>
+ return
+}
+
+
diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
new file mode 100644
index 0000000000000..ab3d6decfb248
--- /dev/null
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_scalar(%arg0: !qalias) {
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_ranked(%arg0: tensor<2x?x4x!qalias>) {
+ %0 = quant.dcast %arg0 : tensor<2x?x4x!qalias> to tensor<2x?x4xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_unranked(%arg0: tensor<*x!qalias>) {
+ %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_static(%arg0: tensor<1x2x3x!qalias>) {
+ %0 = quant.dcast %arg0 : tensor<1x2x3x!qalias> to tensor<1x2x3xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_dynamic(%arg0: tensor<?x?x?x!qalias>) {
+ %0 = quant.dcast %arg0 : tensor<?x?x?x!qalias> to tensor<?x?x?xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_unranked(%arg0: tensor<*x!qalias>) {
+ %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_scalar(%arg0: f32) {
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_ranked(%arg0: tensor<2x?x4xf32>) {
+ %0 = quant.qcast %arg0 : tensor<2x?x4xf32> to tensor<2x?x4x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_unranked(%arg0: tensor<*xf32>) {
+ %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_static(%arg0: tensor<1x2x3xf32>) {
+ %0 = quant.qcast %arg0 : tensor<1x2x3xf32> to tensor<1x2x3x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_dynamic(%arg0: tensor<?x?x?xf32>) {
+ %0 = quant.qcast %arg0 : tensor<?x?x?xf32> to tensor<?x?x?x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) {
+ %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+ return
+}
+
>From 09278de0b804824b0e6c1062e31d1a858549d23c Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 23 Jul 2024 17:56:36 -0400
Subject: [PATCH 15/18] Verifier for 'quant.scast' and unit tests
---
.../mlir/Dialect/Quant/IR/QuantBase.td | 28 +++++-
.../include/mlir/Dialect/Quant/IR/QuantOps.td | 11 ++-
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 65 +++++++++----
mlir/test/Dialect/Quant/invalid.mlir | 97 +++++++++++++++++++
mlir/test/Dialect/Quant/ops.mlir | 55 +++++++++++
5 files changed, 233 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 8ef89a0d1393b..d81838db3dc1a 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -36,19 +36,24 @@ def quant_QuantizedType :
def quant_ScalarType :
Type<Or<[
- AnyInteger.predicate,
+ AnySignlessInteger.predicate,
AnyFloat.predicate,
quant_QuantizedType.predicate
- ]>, "integer, float, or quantized scalar">;
+ ]>,
+ "signless integer, float, or quantized scalar">;
def quant_IntegerOrQuantizedType :
- Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>>;
+ Type<Or<[
+ AnySignlessInteger.predicate,
+ quant_QuantizedType.predicate
+ ]>,
+ "signless integer or quantized type">;
def quant_FloatScalarOrTensor :
quant_ScalarOrTensorOf<AnyFloat>;
def quant_IntegerScalarOrTensor :
- quant_ScalarOrTensorOf<AnyInteger>;
+ quant_ScalarOrTensorOf<AnySignlessInteger>;
def quant_QuantizedScalarOrTensor :
quant_ScalarOrTensorOf<quant_QuantizedType>;
@@ -81,4 +86,19 @@ def quant_SameScalarOrTensorShape :
]>
>;
+def quant_IntegerAndQuantizedCombination :
+ PredOpTrait<
+ "input must be integer and result must be quantized, or vice versa",
+ Or<[
+ And<[
+ TypeIsPred<"input", quant_QuantizedScalarOrTensor>,
+ TypeIsPred<"result", quant_IntegerScalarOrTensor>
+ ]>,
+ And<[
+ TypeIsPred<"input", quant_IntegerScalarOrTensor>,
+ TypeIsPred<"result", quant_QuantizedScalarOrTensor>
+ ]>
+ ]>
+ >;
+
#endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 3a24c8148d1f5..5dab02e8e1ee5 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -98,7 +98,8 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [
def quant_StorageCastOp : quant_Op<"scast", [
Pure,
- quant_SameScalarOrTensorShape]> {
+ quant_SameScalarOrTensorShape,
+ quant_IntegerAndQuantizedCombination]> {
let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
let description = [{
A StorageCast `scast` represents a cast from or to a type based on the
@@ -122,7 +123,15 @@ def quant_StorageCastOp : quant_Op<"scast", [
let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+ let hasVerifier = 1;
let hasFolder = 1;
+ let extraClassDeclaration = [{
+ /// Return the integer type used either in the input or the result.
+ IntegerType getIntegerType();
+
+ /// Return the quantized type used either in the input or the result.
+ quant::QuantizedType getQuantizedType();
+ }];
}
#endif // QUANT_OPS
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index ef66ccf3d9e01..f722eb8e30806 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -86,10 +86,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
return op->emitError(
"expressed type in quantized type expected to match float type");
- if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
- return failure();
-
- return success();
+ // Verify integrity of per-axis quantization information, if present.
+ return verifyPerAxisQuantization(op, quantizedType, containerType);
}
} // namespace
@@ -128,20 +126,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() {
}
-//===----------------------------------------------------------------------===//
-// StorageCastOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
- // Matches x -> [scast -> scast] -> y, replacing the second scast with the
- // value of x if the casts invert each other.
- auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
- if (!srcScastOp || srcScastOp.getInput().getType() != getType())
- return OpFoldResult();
- return srcScastOp.getInput();
-}
-
-
//===----------------------------------------------------------------------===//
// QuantizeCastOp
//===----------------------------------------------------------------------===//
@@ -160,6 +144,51 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
}
+//===----------------------------------------------------------------------===//
+// StorageCastOp
+//===----------------------------------------------------------------------===//
+
+IntegerType StorageCastOp::getIntegerType() {
+ auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+ if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
+ return integerType;
+
+ auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+ return cast<IntegerType>(resultScalarType);
+}
+
+QuantizedType StorageCastOp::getQuantizedType() {
+ auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+ if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
+ return quantizedType;
+
+ auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+ return cast<QuantizedType>(resultScalarType);
+}
+
+OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> [scast -> scast] -> y, replacing the second scast with the
+ // value of x if the casts invert each other.
+ auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+ if (!srcScastOp || srcScastOp.getInput().getType() != getType())
+ return OpFoldResult();
+ return srcScastOp.getInput();
+}
+
+LogicalResult StorageCastOp::verify() {
+ auto quantizedType = getQuantizedType();
+ auto integerType = getIntegerType();
+ if (quantizedType.getStorageType() != integerType)
+ return emitError(
+ "storage type in quantized type expected to match integer type");
+
+ // Verify integrity of per-axis quantization information, if available. While
+ // the quantization type may appear in the input or the result, their tensor
+ // shapes are guaranteed to be identical at this point.
+ return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
+}
+
+
} // namespace quant
} // namespace mlir
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
index 9976ce5d00d65..ba3a8e312d96e 100644
--- a/mlir/test/Dialect/Quant/invalid.mlir
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -158,4 +158,101 @@ func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3x4xf32>) {
return
}
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_input(%arg0: si32) {
+ // expected-error@+1 {{operand #0 must be scalar or tensor of signless integer or quantized type}}
+ %0 = quant.scast %arg0 : si32 to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_result(%arg0: !qalias) {
+ // expected-error@+1 {{result #0 must be scalar or tensor of signless integer or quantized type}}
+ %0 = quant.scast %arg0 : !qalias to si32
+ return
+}
+
+// -----
+
+func.func @scast_both_integers(%arg0: i8) {
+ // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+ %0 = quant.scast %arg0 : i8 to i8
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_both_quantized(%arg0: !qalias) {
+ // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+ %0 = quant.scast %arg0 : !qalias to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_scalar_tensor(%arg0: tensor<i8>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.scast %arg0 : tensor<i8> to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_ranked_unranked_tensor(%arg0: tensor<i8>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.scast %arg0 : tensor<i8> to tensor<*x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xi8>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<?x3x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_integer_type_mismatch(%arg0: i32) {
+ // expected-error@+1 {{storage type in quantized type expected to match integer type}}
+ %0 = quant.scast %arg0 : i32 to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_scalar(%arg0: i8) {
+ // expected-error@+1 {{scalar types may not use per-axis quantization}}
+ %0 = quant.scast %arg0 : i8 to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3xi8>) {
+ // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+ %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<2x3x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_invalid_num_scales(%arg0: tensor<2x3x4xi8>) {
+ // expected-error@+1 {{quantized dimension size does not match number of scales}}
+ %0 = quant.scast %arg0 : tensor<2x3x4xi8> to tensor<2x3x4x!qalias>
+ return
+}
diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
index ab3d6decfb248..4abc5830d081e 100644
--- a/mlir/test/Dialect/Quant/ops.mlir
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -94,3 +94,58 @@ func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) {
return
}
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_scalar(%arg0: i8) {
+ %0 = quant.scast %arg0 : i8 to !qalias
+ %1 = quant.scast %0 : !qalias to i8
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_ranked(%arg0: tensor<2x?x4xi8>) {
+ %0 = quant.scast %arg0 : tensor<2x?x4xi8> to tensor<2x?x4x!qalias>
+ %1 = quant.scast %0 : tensor<2x?x4x!qalias> to tensor<2x?x4xi8>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_unranked(%arg0: tensor<*xi8>) {
+ %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias>
+ %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_static(%arg0: tensor<1x2x3xi8>) {
+ %0 = quant.scast %arg0 : tensor<1x2x3xi8> to tensor<1x2x3x!qalias>
+ %1 = quant.scast %0 : tensor<1x2x3x!qalias> to tensor<1x2x3xi8>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_dynamic(%arg0: tensor<?x?x?xi8>) {
+ %0 = quant.scast %arg0 : tensor<?x?x?xi8> to tensor<?x?x?x!qalias>
+ %1 = quant.scast %0 : tensor<?x?x?x!qalias> to tensor<?x?x?xi8>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) {
+ %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias>
+ %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8>
+ return
+}
+
+
>From 764b1d5afaca8ac397d85157c54e2717fcefd0ac Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 24 Jul 2024 08:08:40 -0400
Subject: [PATCH 16/18] Canonicalization patterns for all ops and unit tests
---
.../include/mlir/Dialect/Quant/IR/QuantOps.td | 2 +
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 67 ++++++---
mlir/test/Dialect/Quant/canonicalize.mlir | 134 +++++++++++++++---
3 files changed, 164 insertions(+), 39 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 5dab02e8e1ee5..036940119b349 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -48,6 +48,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [
let results = (outs quant_FloatScalarOrTensor:$result);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
let hasVerifier = 1;
+ let hasFolder = 1;
let extraClassDeclaration = [{
/// Return the float type of the scalar or tensor result.
FloatType getFloatType();
@@ -87,6 +88,7 @@ def quant_QuantizeCastOp : quant_Op<"qcast", [
let results = (outs quant_QuantizedScalarOrTensor:$result);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
let hasVerifier = 1;
+ let hasFolder = 1;
let extraClassDeclaration = [{
/// Return the float type of the scalar or tensor input.
FloatType getFloatType();
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index f722eb8e30806..6a709488ce01c 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -117,6 +117,17 @@ LogicalResult DequantizeCastOp::verify() {
getInput().getType());
}
+OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
+ // with the value of x. The shape traits do not compare tensor encodings, so
+ // x and y may still differ in type; bail out then, as a fold must keep type.
+ auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
+ if (!srcQcastOp || srcQcastOp.getInput().getType() != getType())
+ return {};
+ // Element types and shapes of x and y already agree via the op verifiers.
+ return srcQcastOp.getInput();
+}
+
FloatType DequantizeCastOp::getFloatType() {
return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
}
@@ -135,6 +146,18 @@ LogicalResult QuantizeCastOp::verify() {
getInput().getType());
}
+OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
+ // with the value of x if the casts invert each other. Contrary to the folding
+ // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
+ // x and y are not guaranteed to be of the same type here, as they may use
+ // different quantization parameters.
+ auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
+ if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
+ return {};
+ return srcDcastOp.getInput();
+}
+
FloatType QuantizeCastOp::getFloatType() {
return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
}
@@ -148,6 +171,28 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
// StorageCastOp
//===----------------------------------------------------------------------===//
+LogicalResult StorageCastOp::verify() {
+ auto quantizedType = getQuantizedType();
+ auto integerType = getIntegerType();
+ if (quantizedType.getStorageType() != integerType)
+ return emitError(
+ "storage type in quantized type expected to match integer type");
+
+ // Verify integrity of per-axis quantization information, if available. While
+ // the quantization type may appear in the input or the result, their tensor
+ // shapes are guaranteed to be identical at this point.
+ return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
+}
+
+OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> quant.scast -> quant.scast -> y, replacing the second
+ // quant.scast with the value of x if the casts invert each other.
+ auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+ if (!srcScastOp || srcScastOp.getInput().getType() != getType())
+ return {};
+ return srcScastOp.getInput();
+}
+
IntegerType StorageCastOp::getIntegerType() {
auto inputScalarType = getElementTypeOrSelf(getInput().getType());
if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
@@ -166,28 +211,6 @@ QuantizedType StorageCastOp::getQuantizedType() {
return cast<QuantizedType>(resultScalarType);
}
-OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
- // Matches x -> [scast -> scast] -> y, replacing the second scast with the
- // value of x if the casts invert each other.
- auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
- if (!srcScastOp || srcScastOp.getInput().getType() != getType())
- return OpFoldResult();
- return srcScastOp.getInput();
-}
-
-LogicalResult StorageCastOp::verify() {
- auto quantizedType = getQuantizedType();
- auto integerType = getIntegerType();
- if (quantizedType.getStorageType() != integerType)
- return emitError(
- "storage type in quantized type expected to match integer type");
-
- // Verify integrity of per-axis quantization information, if available. While
- // the quantization type may appear in the input or the result, their tensor
- // shapes are guaranteed to be identical at this point.
- return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
-}
-
} // namespace quant
} // namespace mlir
diff --git a/mlir/test/Dialect/Quant/canonicalize.mlir b/mlir/test/Dialect/Quant/canonicalize.mlir
index 36c3eaf5e10d2..73c57e2a48212 100644
--- a/mlir/test/Dialect/Quant/canonicalize.mlir
+++ b/mlir/test/Dialect/Quant/canonicalize.mlir
@@ -1,24 +1,124 @@
// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
+// `quant.dcast` of a `quant.qcast` result folds to the original input value.
+// CHECK-LABEL: @dcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+ %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
+
// -----
-// CHECK-LABEL: redundant_scast
-func.func @redundant_scast() -> tensor<4xi8> {
- // CHECK-NEXT: arith.constant dense<10> : tensor<4xi8>
- // CHECK-NEXT: return
- %cst = arith.constant dense<5> : tensor<4xi8>
- %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
- %2 = "quant.scast"(%1) : (tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<4xi8>
- %3 = arith.addi %2, %2 : tensor<4xi8>
- return %3 : tensor<4xi8>
+
+// `quant.dcast` does not fold when its input is not produced by `quant.qcast`.
+// CHECK-LABEL: @dcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.dcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_no_fold_source(%arg0: tensor<4xi8>) -> tensor<4xf32> {
+ %0 = quant.scast %arg0 : tensor<4xi8> to tensor<4x!qalias>
+ %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+ return %1 : tensor<4xf32>
}
// -----
-// CHECK-LABEL: non_redundant_scast
-func.func @non_redundant_scast() -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>> {
- // CHECK-NEXT: arith.constant dense<5> : tensor<4xi8>
- // CHECK-NEXT: scast
- // CHECK-NEXT: return
- %cst = arith.constant dense<5> : tensor<4xi8>
- %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
- return %1 : tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
+
+// `quant.qcast` of a `quant.dcast` result folds to the original quantized
+// input when both quantized types are the same.
+// CHECK-LABEL: @qcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+ %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+ %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+ return %1 : tensor<4x!qalias>
}
+
+// -----
+
+// `quant.qcast` does not fold when its input is not produced by `quant.dcast`.
+// CHECK-LABEL: @qcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = arith.negf %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4x!qalias> {
+ %0 = arith.negf %arg0 : tensor<4xf32>
+ %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+ return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// `quant.qcast` does not fold when its result type differs from the quantized
+// type of the `quant.dcast` input.
+// CHECK-LABEL: @qcast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.dcast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @qcast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+ %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+ %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias1>
+ return %1 : tensor<4x!qalias1>
+}
+
+// -----
+
+// A pair of `quant.scast` ops that invert each other folds to the original
+// input value.
+// CHECK-LABEL: @scast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+ %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+ %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias>
+ return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// `quant.scast` does not fold when its input is not produced by another
+// `quant.scast`.
+// CHECK-LABEL: @scast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[QCAST:.*]] = quant.qcast %[[ARG_0]]
+// CHECK: %[[SCAST:.*]] = quant.scast %[[QCAST]]
+// CHECK: return %[[SCAST]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4xi8> {
+ %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+ %1 = quant.scast %0 : tensor<4x!qalias> to tensor<4xi8>
+ return %1 : tensor<4xi8>
+}
+
+// -----
+
+// A pair of `quant.scast` ops does not fold when the final result type differs
+// from the type of the original input.
+// CHECK-LABEL: @scast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.scast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @scast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+ %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+ %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias1>
+ return %1 : tensor<4x!qalias1>
+}
+
>From 70c5af961b01b91b7d1664ebad7fa528a37b35af Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 25 Jul 2024 00:23:28 -0400
Subject: [PATCH 17/18] Pass 'strip-func-quant-types'
---
.../mlir/Dialect/Quant/Transforms/Passes.td | 15 +++
.../Dialect/Quant/Transforms/CMakeLists.txt | 1 +
.../Quant/Transforms/StripFuncQuantTypes.cpp | 114 ++++++++++++++++++
.../Dialect/Quant/strip-func-quant-types.mlir | 88 ++++++++++++++
4 files changed, 218 insertions(+)
create mode 100644 mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
create mode 100644 mlir/test/Dialect/Quant/strip-func-quant-types.mlir
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index 56e10688b0c98..b25296d4db5a9 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -31,4 +31,19 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
];
}
+def StripFuncQuantTypes : Pass<"strip-func-quant-types"> {
+ let summary = "Strip quantized types from function headers";
+ let description = [{
+ Identify occurrences of function arguments and results using a quantized
+ type (or a tensor of quantized elements) and replace them with the
+ corresponding storage (signless integer) type. For each converted argument,
+ a `quant.scast` op is introduced at the head of the function's entry block
+ converting the new integer argument into the original quantized value. For
+ each converted result, a `quant.scast` op converts the quantized value into
+ its storage type before it is returned. Quantized operands and results of
+ `func.call` ops are cast in the same way at the call site.
+ }];
+ let dependentDialects = [
+ "func::FuncDialect",
+ "quant::QuantDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 2daea7750cfe3..662c3e368b624 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRQuantTransforms
LowerQuantOps.cpp
+ StripFuncQuantTypes.cpp
ADDITIONAL_HEADER_DIRS
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
new file mode 100644
index 0000000000000..8996eff61a39c
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -0,0 +1,114 @@
+//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Strips quantized types from function headers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+// Type converter that replaces a quantized type with its storage type, and a
+// tensor whose element type is quantized with the same tensor using the
+// storage element type. All other types are left unchanged.
+class QuantizedTypeConverter : public TypeConverter {
+
+ // Map a quantized scalar type to its storage type
+ // (e.g. `!quant.uniform<i8:f32, ...>` -> `i8`).
+ static Type convertQuantizedType(QuantizedType quantizedType) {
+ return quantizedType.getStorageType();
+ }
+
+ // Map a tensor with a quantized element type to a tensor of the same shape
+ // (ranked or unranked, via `clone`) with the storage element type. Tensors
+ // with non-quantized elements are returned as-is.
+ static Type convertTensorType(TensorType tensorType) {
+ if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
+ return tensorType.clone(convertQuantizedType(quantizedType));
+ return tensorType;
+ }
+
+ // Bridge between stripped and quantized values by creating a single
+ // `quant.scast` op from `inputs[0]` to `type`. The same routine serves as
+ // argument, source, and target materialization, since `quant.scast` casts
+ // in either direction. Exactly one input value is expected.
+ static Value materializeConversion(OpBuilder &builder, Type type,
+ ValueRange inputs, Location loc) {
+ assert(inputs.size() == 1);
+ return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+ }
+
+public:
+
+ explicit QuantizedTypeConverter() {
+ // NOTE: per the MLIR TypeConverter documentation, conversions are tried
+ // in reverse order of registration. The identity fallback is therefore
+ // registered first so that the specific conversions below take priority.
+ addConversion([](Type type) { return type; });
+ addConversion(convertQuantizedType);
+ addConversion(convertTensorType);
+
+ addArgumentMaterialization(materializeConversion);
+ addSourceMaterialization(materializeConversion);
+ addTargetMaterialization(materializeConversion);
+ }
+};
+
+// Conversion pass that strips quantized types from function signatures,
+// `func.return` operands, and `func.call` operands/results. Each quantized
+// type (or tensor of quantized elements) is replaced by its storage type, and
+// the type converter's materializations insert `quant.scast` ops wherever a
+// stripped value meets a use that still expects the quantized type.
+class StripFuncQuantTypes
+    : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+public:
+  void runOnOperation() override {
+    // The pass is declared without an anchor op, so it may be scheduled on
+    // any operation. Do not assume a `builtin.module` here: a
+    // `cast<ModuleOp>` would assert when the pass is nested under another op.
+    // The dialect conversion driver accepts an arbitrary root operation.
+    Operation *rootOp = getOperation();
+    MLIRContext *context = &getContext();
+
+    QuantizedTypeConverter typeConverter;
+    ConversionTarget target(*context);
+    RewritePatternSet patterns(context);
+
+    // Mark func.func, func.return, and func.call illegal if they contain any
+    // quantized types.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>(
+        [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
+    target.addDynamicallyLegalOp<func::CallOp>(
+        [&](func::CallOp op) { return typeConverter.isLegal(op); });
+
+    // Register conversion patterns.
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+
+    // Apply conversion. Ops with no registered legality are left untouched by
+    // the partial conversion.
+    if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir
+
diff --git a/mlir/test/Dialect/Quant/strip-func-quant-types.mlir b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
new file mode 100644
index 0000000000000..e5f0d4921bed3
--- /dev/null
+++ b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s
+
+// Quantized function arguments are stripped to their storage types, and a
+// `quant.scast` at the head of the entry block converts each one back to its
+// original quantized type for the existing uses. Non-quantized arguments are
+// left untouched.
+// CHECK-LABEL: @strip_operands
+// CHECK-SAME: %[[ARG_0:.*]]: i8
+// CHECK-SAME: %[[ARG_1:.*]]: i16
+// CHECK-SAME: %[[ARG_2:.*]]: f32
+
+// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}>
+// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}>
+
+// CHECK: "test.custom_op"(%[[ARG_0_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_1_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_2]])
+// CHECK: return
+
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) {
+ "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias>
+ "test.custom_op"(%arg1) : (!qalias1) -> tensor<?x!qalias1>
+ "test.custom_op"(%arg2) : (f32) -> tensor<4xf32>
+ return
+}
+
+// -----
+
+// Quantized results are stripped from the function type. Each quantized value
+// is converted to its storage type with a `quant.scast` before being returned,
+// while non-quantized results are returned unchanged.
+// CHECK-LABEL: @strip_results
+// CHECK-SAME: tensor<4xi8>, tensor<?xi16>, tensor<*xi8>, tensor<4xf32>
+
+// CHECK: %[[RESULT_0:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[RESULT_1:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULT_2:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8>
+
+// CHECK: %[[RESULT_3:.*]] = "test.custom_op"()
+
+// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]]
+
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_results() -> (tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>) {
+ %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+ %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+ %2 = "test.custom_op"() : () -> tensor<*x!qalias>
+ %3 = "test.custom_op"() : () -> tensor<4xf32>
+ return %0, %1, %2, %3 : tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>
+}
+
+// -----
+
+
+// Quantized types are stripped from the signature of the private declaration
+// `@callee` and from the `func.call` site. Quantized operands are cast to their
+// storage types before the call, and quantized results are cast back to their
+// quantized types after it.
+// CHECK-LABEL: @callee
+// CHECK-SAME: (tensor<4xi8>, tensor<?xi16>) -> (tensor<*xi8>, tensor<4xf32>)
+
+// CHECK-LABEL: @strip_call
+
+// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]])
+
+// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>>
+// CHECK: "test.custom_op"(%[[RESULT_0_CAST]])
+
+// CHECK: "test.custom_op"(%[[RESULTS]]#1)
+
+// CHECK: return
+
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func private @callee(tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+
+func.func @strip_call() {
+ %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+ %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+ %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+ "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> ()
+ "test.custom_op"(%2#1) : (tensor<4xf32>) -> ()
+ return
+}
>From cce8171c6d016d823e514ec304f94d2e8c4085c0 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 25 Jul 2024 18:43:30 -0400
Subject: [PATCH 18/18] Dialect documentation progress
---
.../mlir/Dialect/Quant/IR/QuantBase.td | 53 +++++-
.../include/mlir/Dialect/Quant/IR/QuantOps.td | 178 +++++++++++++-----
2 files changed, 186 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index d81838db3dc1a..2690b4fe0b111 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -17,8 +17,59 @@ include "mlir/IR/OpBase.td"
def Quant_Dialect : Dialect {
let name = "quant";
+ let description = [{
+ ## Per-axis quantization integrity
+
+ When type `!quant.uniform` contains per-axis quantization information, the
+ rules below apply. These rules guarantee that the quantization information
+ encoded in the data type is applicable to the context in which the
+ quantized type is used. For efficiency, these rules are actively enforced
+ only by the verifiers of `quant` dialect ops. They must nevertheless be
+ respected in any context in which the `!quant.uniform` data type is used,
+ such as the header of a `func.func` op or the input of an arithmetic
+ operation.
+
+ - A quantized type with per-channel quantization information must be the
+ element type of a tensor container type, and may not occur directly as
+ the data type of a scalar value.
+
+ ```
+ // Incorrect. Type `!quant.uniform` specifies per-channel quantization for a
+ // scalar type.
+ %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0, 2.0}>
+
+ // Correct. Type `!quant.uniform` with per-channel quantization is wrapped in
+ // a `tensor` type.
+ %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
+ ```
+
+ - If the tensor containing the `!quant.uniform` type is ranked, its rank
+ must be greater than the channel axis specified in the quantized type.
+
+ ```
+ // Incorrect. The tensor rank (2) is not greater than the channel axis in the
+ // quantized type (3).
+ %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>
+
+ // Correct. The tensor rank (2) is now greater than the channel axis (1).
+ %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:1, {1.0, 2.0}>>
+ ```
+
+ - If the axis dimension in the containing tensor is static, its size must
+ be equal to the number of scales present in the quantized type.
+
+ ```
+ // Incorrect. The channel axis is 1, and the size of dimension 1 in the
+ // containing tensor is 3. However, there are 4 scale values present in the
+ // quantized type.
+ %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>
+
+ // Correct. The quantized type now includes 3 scale values, matching the size
+ // of dimension 1 of the result tensor.
+ %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+ ```
+ }];
let cppNamespace = "::mlir::quant";
-
let useDefaultTypePrinterParser = 1;
}
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
index 036940119b349..52dfc6b051de7 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -31,18 +31,55 @@ class quant_Op<string mnemonic, list<Trait> traits> :
def quant_DequantizeCastOp : quant_Op<"dcast", [
Pure,
quant_SameScalarOrTensorShape]> {
- let summary = "convert back from a quantized to quantizable (expressed) type operation";
+ let summary = "Dequantize cast operation";
let description = [{
- A DequantizeCast op `dcast` represents the inverse of a `qcast`,
- converting back from a quantized to quantizable (expressed) type.
+ Convert an input quantized value into its expressed floating-point value.
+ The dequantization process consists of the following steps:
- Like `qcast`s, a `dcast` is allowed to have both its operand and result
- as non quantized types. This facilitates transformations and marks edges
- where the computation must be carried out in the expressed type.
+ ```
+ def dequantize(quantizedValue: quantizedType) -> expressedType:
+ storedValue = reinterpretCast(quantizedValue, storageType)
+ storedValueFloat = convertIntToFloat(storedValue, expressedType)
+ zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+ expressedValue = (storedValueFloat - zeroPointFloat) * scale
+ return expressedValue
+ ```
+
+ Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained
+ from the corresponding parameters encoded in `quantizedType`. For
+ per-channel quantization, the appropriate `scale` and `zeroPoint` values
+ are used for each tensor element computation according to the channel the
+ element belongs to.
+
+ The operation must satisfy the following syntactic constraints:
+
+ - Operand `input` must be a scalar or tensor of type `!quant.uniform`.
+
+ - The result type must be a floating-point scalar or tensor.
+
+ - The `expressedType` parameter of the `!quant.uniform` type of the input
+ must match the floating-point type of the result.
+
+ - The operand and result types must be both scalars or both tensors. If
+ tensors, they must be both ranked or both unranked. If ranked, both must
+ have the same shape, including matching static and dynamic dimensions.
+
+ - If the operand uses per-channel quantization, its `!quant.uniform` type
+ must adhere to the [Per-axis quantization
+ integrity](#per-axis-quantization-integrity) guidelines.
+
+ Examples:
+
+ ```
+ // Dequantize a scalar quantized value
+ %result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
+
+ // Dequantize a dynamically shaped tensor of quantized values
+ %result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
- Especially early in transformation, it is common to have `dcast`s on
- all operands to ops that must operate with the expressed type (typically
- math ops prior to lowering to target-specific, quantized kernels).
+ // Dequantize an unranked tensor using per-axis quantization information
+ %result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>
+ ```
}];
let arguments = (ins quant_QuantizedScalarOrTensor:$input);
let results = (outs quant_FloatScalarOrTensor:$result);
@@ -61,28 +98,57 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [
def quant_QuantizeCastOp : quant_Op<"qcast", [
Pure,
quant_SameScalarOrTensorShape]> {
- let summary = "convert a quantizable type to a quantized type";
+ let summary = "Quantize cast operation";
let description = [{
- A QuantizeCast `qcast` represents a potential type shift from a quantizable
- type to a quantized type.
-
- At runtime, a `qcast` will apply the transformation expressed by its
- operand and result type. For flexibility during transformation, it is also
- possible to have a `qcast` that performs no transformation (both its
- operand and result type are quantizable).
-
- A `qcast` will typically originate from either:
- a) An expressed or implied constraint in the source dialect which signals
- that a certain level of quantization is possible or required.
- b) An inference made by a quantization algorithm indicating that a
- quantized representation may be acceptable.
-
- Especially early in transformation, it is common to have pairs of
- `qcast` and `dcast` at points where a transition to a quantized type is
- required. In addition, it is also common to have an identity `qcast`
- (where the operand and result type are not quantized) at all points where
- it is legal to use a quantized representation (but is not known to be
- acceptable).
+ Convert a floating-point value to a quantized type. The quantization
+ process consists of the following steps:
+
+ ```
+ def quantize(expressedValue: expressedType) -> quantizedType:
+ zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+ scaledValue = expressedValue / scale
+ storedValueFloat = scaledValue + zeroPointFloat
+ storedValue = convertFloatToInt(storedValueFloat, storageType)
+ storedValueClamped = clamp(storedValue, storageMin, storageMax)
+ quantizedValue = reinterpretCast(storedValueClamped, quantizedType)
+ return quantizedValue
+ ```
+
+ Here, `storageType`, `storageMin`, `storageMax`, `expressedType`, `scale`,
+ and `zeroPoint` are obtained from the corresponding parameters encoded in
+ `quantizedType`. For per-channel quantization, the appropriate `scale` and
+ `zeroPoint` values are used for each tensor element computation according
+ to the channel the element belongs to.
+
+ The operation must satisfy the following syntactic constraints:
+
+ - Operand `input` must be a floating-point scalar or tensor.
+
+ - The result type must be a scalar or tensor of type `!quant.uniform`.
+
+ - The `expressedType` parameter in the `!quant.uniform` type of the result
+ must match the floating-point type of the input.
+
+ - The operand and result types must be both scalars or both tensors. If
+ tensors, they must be both ranked or both unranked. If ranked, both must
+ have the same shape, including matching static and dynamic dimensions.
+
+ - If the result uses per-channel quantization, its `!quant.uniform` type
+ must adhere to the [Per-axis quantization
+ integrity](#per-axis-quantization-integrity) guidelines.
+
+ Examples:
+
+ ```
+ // Quantize a scalar floating-point value
+ %result = quant.qcast %input : f32 to !quant.uniform<i8:f32, 2.0>
+
+ // Quantize a dynamically shaped tensor of floating-point values
+ %result = quant.qcast %input : tensor<?xf32> to tensor<?x!quant.uniform<i8:f32, 2.0>>
+
+ // Quantize an unranked tensor using per-axis quantization information
+ %result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
+ ```
}];
let arguments = (ins quant_FloatScalarOrTensor:$input);
let results = (outs quant_QuantizedScalarOrTensor:$result);
@@ -102,24 +168,48 @@ def quant_StorageCastOp : quant_Op<"scast", [
Pure,
quant_SameScalarOrTensorShape,
quant_IntegerAndQuantizedCombination]> {
- let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
+ let summary = "Storage cast operation";
let description = [{
- A StorageCast `scast` represents a cast from or to a type based on the
- storage type and a type based on a corresponding quantized type.
+ Convert a value from a quantized type to the corresponding signless integer
+ storage type, or vice versa. This conversion simply involves a
+ reinterpretation of the input bits and does not involve any data
+ manipulation.
- This op exists to ensure type coherency for between parts of the computation
- which are operating directly on an underlying storage type and those which
- operate on quantized values.
+ The operation must satisfy the following syntactic constraints:
+
+ - Operand `input` must be a scalar or tensor of a signless integer or
+ `!quant.uniform` type.
+
+ - The result must be a scalar or tensor of a signless integer or
+ `!quant.uniform` type.
+
+ - If the operand is a scalar or tensor of a signless integer type, the
+ result must be a scalar or tensor of type `!quant.uniform`, and vice versa.
+
+ - The operand and result must be both scalars or both tensors. If tensors,
+ they must be both ranked or both unranked. If ranked, both must have the
+ same shape, including matching static and dynamic dimensions.
+
+ - The width of the `storageType` parameter of the quantized type of the
+ operand or result must match the width of the signless integer type of
+ the operand or result.
+
+ - If the operand or result uses per-channel quantization, its
+ `!quant.uniform` type must adhere to the [Per-axis quantization
+ integrity](#per-axis-quantization-integrity) guidelines.
+
+ Examples:
- Examples from storage to quantized type:
- ```
- i8 -> !quant<"uniform[i8:f32]{1.0}">
- ```
- ```
- tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- ```
```
- vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+ // Cast a scalar quantized value into its storage type
+ %result = quant.scast %input : !quant.uniform<i8:f32, 2.0> to i8
+
+ // Cast a dynamically shaped tensor of quantized values into their storage type
+ %result = quant.scast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xi8>
+
+ // Cast an unranked tensor of signless integers into a quantized type using
+ // per-channel quantization
+ %result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
```
}];
let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
More information about the Mlir-commits
mailing list