[Mlir-commits] [mlir] 7fa1d74 - Reland "[mlir][arith] Add wide integer emulation pass"

Jakub Kuderski llvmlistbot at llvm.org
Thu Sep 8 20:37:33 PDT 2022

Author: Jakub Kuderski
Date: 2022-09-08T23:30:47-04:00
New Revision: 7fa1d743d073b4af6acb0a34b6324edf1d92f518

URL: https://github.com/llvm/llvm-project/commit/7fa1d743d073b4af6acb0a34b6324edf1d92f518
DIFF: https://github.com/llvm/llvm-project/commit/7fa1d743d073b4af6acb0a34b6324edf1d92f518.diff

LOG: Reland "[mlir][arith] Add wide integer emulation pass"

This reverts commit 45b5e8abe56d7f28c88b0c6cdd60ff741874fb1d.

Relands https://reviews.llvm.org/D133135 after fixing the shared-libraries build.




diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
index 922d653decbb..5ee3fb0d7236 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -15,16 +15,25 @@ namespace mlir {
 namespace arith {
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+class WideIntEmulationConverter;
 /// Create a pass to bufferize Arithmetic ops.
 std::unique_ptr<Pass> createArithmeticBufferizePass();
 /// Create a pass to bufferize arith.constant ops.
 std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
+/// Adds patterns that emulate Arithmetic and Function ops over integer types
+/// that are too wide for the target using supported ones. This is done by
+/// splitting each original power-of-two i2N integer type into two iN halves.
+void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
+                                      RewritePatternSet &patterns);
 /// Add patterns to expand Arithmetic ops for LLVM lowering.
 void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
index 752d71508795..3895562079d7 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -49,4 +49,24 @@ def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
   let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
+def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> {
+  let summary = "Emulate 2*N-bit integer operations using N-bit operations";
+  let description = [{
+    Emulate integer operations that use integer types too wide for the target
+    with equivalent operations on supported narrow integer types. This is done
+    by splitting original integer values into two halves.
+    This pass is intended to preserve semantics but not necessarily to provide
+    the most efficient implementation.
+    TODO: Optimize op emulation.
+    Currently, only power-of-two integer bitwidths are supported.
+  }];
+  let options = [
+    Option<"widestIntSupported", "widest-int-supported", "unsigned",
+           /*default=*/"32", "Widest integer type supported by the target">,
+  ];
+  let dependentDialects = ["vector::VectorDialect"];

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
new file mode 100644
index 000000000000..814db23e4340
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
@@ -0,0 +1,34 @@
+//===- WideIntEmulationConverter.h - Type Converter for WIE -----*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::arith {
+/// Converts integer types that are too wide for the target by splitting them
+/// into two halves, thus turning them into supported ones, i.e., i2*N --> iN,
+/// where N is the widest integer bitwidth supported by the target.
+/// Currently, we only handle power-of-two integer types and support conversions
+/// of integers twice as wide as the maximum supported by the target. Wide
+/// integers are represented as vectors, e.g., i64 --> vector<2xi32>, where the
+/// first element is the low half of the original integer, and the second
+/// element is the high half.
+class WideIntEmulationConverter : public TypeConverter {
+  explicit WideIntEmulationConverter(unsigned widestIntSupportedByTarget);
+  unsigned getMaxTargetIntBitWidth() const { return maxIntWidth; }
+  unsigned maxIntWidth;
+} // namespace mlir::arith

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
index f140715e603e..ba68d36de0e9 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
+  EmulateWideInt.cpp
@@ -15,9 +16,13 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
+  MLIRFuncDialect
+  MLIRFuncTransforms
+  MLIRTransformUtils
+  MLIRVectorDialect

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
new file mode 100644
index 000000000000..94e321ba1ad7
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -0,0 +1,120 @@
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+namespace mlir::arith {
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+using namespace mlir;
+namespace {
+struct EmulateWideIntPass final
+    : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
+  using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
+  /// Rewrites func.func signatures, func.call, and func.return so that integer
+  /// types twice as wide as `widest-int-supported` become vectors of two
+  /// narrower halves. Fails with a diagnostic on an invalid option value.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    // WideIntEmulationConverter asserts that the width is a power of two, so
+    // reject bad option values here with a user-facing error instead of
+    // failing silently. A width of 1 is rejected as well: it would only let
+    // the pass split i2 into two i1 halves.
+    if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
+      op->emitError("'widest-int-supported' must be a power of two no smaller "
+                    "than 2, got ")
+          << widestIntSupported;
+      signalPassFailure();
+      return;
+    }
+    MLIRContext *ctx = op->getContext();
+    arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+    ConversionTarget target(*ctx);
+    // A func.func is legal once its function type contains only legal types.
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    // Calls and returns are legal once all operand and result types are legal.
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+    RewritePatternSet patterns(ctx);
+    arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+} // end anonymous namespace
+    unsigned widestIntSupportedByTarget)
+    : maxIntWidth(widestIntSupportedByTarget) {
+  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
+         "Only power-of-two integers are supported");
+  // Scalar case.
+  addConversion([this](IntegerType ty) -> Optional<Type> {
+    unsigned width = ty.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+    // i2N --> vector<2xiN>
+    if (width == 2 * maxIntWidth)
+      return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
+    return None;
+  });
+  // Vector case.
+  addConversion([this](VectorType ty) -> Optional<Type> {
+    auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+    if (!intTy)
+      return ty;
+    unsigned width = intTy.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+    // vector<...xi2N> --> vector<...x2xiN>
+    if (width == 2 * maxIntWidth) {
+      auto newShape = to_vector(ty.getShape());
+      newShape.push_back(2);
+      return VectorType::get(newShape,
+                             IntegerType::get(ty.getContext(), maxIntWidth));
+    }
+    return None;
+  });
+  // Function case.
+  addConversion([this](FunctionType ty) -> Optional<Type> {
+    // Convert inputs and results, e.g.:
+    //   (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return None;
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return None;
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+void arith::populateWideIntEmulationPatterns(
+    WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // Populate `func.*` conversion patterns.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);

diff  --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
new file mode 100644
index 000000000000..aafeb5b3a1cd
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_i32
+// CHECK-SAME:    ([[ARG:%.+]]: i32) -> i32
+// CHECK-NEXT:    [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32
+// CHECK-NEXT:    return [[X]] : i32
+func.func @addi_same_i32(%a : i32) -> i32 {
+    %x = arith.addi %a, %a : i32
+    return %x : i32
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_vector_i32
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32>
+// CHECK-NEXT:    return [[X]] : vector<2xi32>
+func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> {
+    %x = arith.addi %a, %a : vector<2xi32>
+    return %x : vector<2xi32>
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME:     ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:     return [[ARG]] : vector<2xi32>
+func.func @identity_scalar(%x : i64) -> i64 {
+    return %x : i64
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:     return [[ARG]] : vector<4x2xi32>
+func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> {
+    return %x : vector<4xi64>
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME:     ([[ARG:%.+]]: vector<3x4x2xi32>) -> vector<3x4x2xi32>
+// CHECK-NEXT:     return [[ARG]] : vector<3x4x2xi32>
+func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> {
+    return %x : vector<3x4xi64>
+// CHECK-LABEL: func @call
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:     [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:     return [[RES]] : vector<4x2xi32>
+func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
+    %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
+    return %res : vector<4xi64>


More information about the Mlir-commits mailing list