[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)

Rafael Ubal llvmlistbot at llvm.org
Tue Jul 30 09:28:22 PDT 2024


================
@@ -0,0 +1,114 @@
+//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Strips quantized types from function headers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Type converter that strips quantized types down to their storage types.
+///
+/// A `QuantizedType` is converted to the type returned by its
+/// `getStorageType()`. A tensor whose element type is quantized is cloned with
+/// that storage type as its element type. Every other type is returned
+/// unchanged. All materializations insert a `quant::StorageCastOp`.
+class QuantizedTypeConverter : public TypeConverter {
+
+  /// Map a quantized type to its storage type.
+  static Type convertQuantizedType(QuantizedType quantizedType) {
+    return quantizedType.getStorageType();
+  }
+
+  /// Replace a quantized tensor element type with its storage type; tensors
+  /// with any other element type are returned as is.
+  static Type convertTensorType(TensorType tensorType) {
+    if (auto quantizedType =
+            dyn_cast<QuantizedType>(tensorType.getElementType()))
+      return tensorType.clone(convertQuantizedType(quantizedType));
+    return tensorType;
+  }
+
+  /// Bridge a converted and an unconverted value by casting the single input
+  /// value to `type` with a `quant::StorageCastOp`.
+  static Value materializeConversion(OpBuilder &builder, Type type,
+                                     ValueRange inputs, Location loc) {
+    assert(inputs.size() == 1 &&
+           "Expected quantized type materialization to have a single input");
+    return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+  }
+
+public:
+  explicit QuantizedTypeConverter() {
+    // NOTE: per the MLIR TypeConverter documentation, the most recently
+    // registered conversion is tried first, so the identity fallback is
+    // registered before the more specific conversions.
+    addConversion([](Type type) { return type; });
+    addConversion(convertQuantizedType);
+    addConversion(convertTensorType);
+
+    // The same storage cast serves every materialization kind.
+    addArgumentMaterialization(materializeConversion);
+    addSourceMaterialization(materializeConversion);
+    addTargetMaterialization(materializeConversion);
+  }
+};
+
+// Conversion pass
+class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+
+  // Return whether a type is considered legal when occurring in the header of
+  // a function or as an operand to a 'return' op.
+  static bool isLegalType(Type type) {
+    // Peel off tensor wrappers until the innermost element type is reached.
+    Type innermost = type;
+    while (auto asTensor = dyn_cast<TensorType>(innermost))
+      innermost = asTensor.getElementType();
+
+    // Only a quantized (element) type makes the original type illegal.
+    if (isa<quant::QuantizedType>(innermost))
+      return false;
+    return true;
+  }
+
+public:
+
+  void runOnOperation() override {
+    
+    auto moduleOp = cast<ModuleOp>(getOperation());
+    auto* context = &getContext();
+
+    QuantizedTypeConverter typeConverter;
+    ConversionTarget target(*context);
+    RewritePatternSet patterns(context);
+
+    // Mark func.func, func.return, and func.call illegal if they contain any
+    // quantized types.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>(
+        [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
+    target.addDynamicallyLegalOp<func::CallOp>(
----------------
rafaelubalmw wrote:

Adding support for `func.call_indirect` seems a bit unclear to me in this context. I'm considering this code:

```
func.func @f(%arg0: !qalias, %arg1: (!qalias) -> i1) {
  %val = func.call_indirect %arg1(%arg0) : (!qalias) -> i1
  return
}
```

A prerequisite step would involve converting quantized types within function types beyond just `func.func` signatures -- here, the function type used by `%arg1`. I'm particularly unclear on the details of this substitution process. Would it apply to block arguments only, or to any value, even one produced by an unregistered op? In the latter case, how would we guarantee that the affected ops remain well-formed?

After that, we would emit the appropriate `scast` ops for the `func.call_indirect` as we do for `func.call`. That would likely involve adding a new pattern in `mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp` for `func.call_indirect`, which doesn't currently exist. I'm a bit reluctant to disturb more files than strictly necessary outside of the `quant` dialect in this PR.

Maybe worth addressing in the future if the need arises?

https://github.com/llvm/llvm-project/pull/100667


More information about the Mlir-commits mailing list