[Mlir-commits] [mlir] 7fa1d74 - Reland "[mlir][arith] Add wide integer emulation pass"
Jakub Kuderski
llvmlistbot at llvm.org
Thu Sep 8 20:37:33 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-08T23:30:47-04:00
New Revision: 7fa1d743d073b4af6acb0a34b6324edf1d92f518
URL: https://github.com/llvm/llvm-project/commit/7fa1d743d073b4af6acb0a34b6324edf1d92f518
DIFF: https://github.com/llvm/llvm-project/commit/7fa1d743d073b4af6acb0a34b6324edf1d92f518.diff
LOG: Reland "[mlir][arith] Add wide integer emulation pass"
This reverts commit 45b5e8abe56d7f28c88b0c6cdd60ff741874fb1d.
Relands https://reviews.llvm.org/D133135 after fixing shared libs
builds.
Added:
mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
Modified:
mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
index 922d653decbb..5ee3fb0d7236 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -15,16 +15,25 @@ namespace mlir {
namespace arith {
#define GEN_PASS_DECL_ARITHMETICBUFFERIZE
+#define GEN_PASS_DECL_ARITHMETICEMULATEWIDEINT
#define GEN_PASS_DECL_ARITHMETICEXPANDOPS
#define GEN_PASS_DECL_ARITHMETICUNSIGNEDWHENEQUIVALENT
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+/// Forward declaration; the type converter is defined in
+/// WideIntEmulationConverter.h.
+class WideIntEmulationConverter;
+
/// Create a pass to bufferize Arithmetic ops.
std::unique_ptr<Pass> createArithmeticBufferizePass();
/// Create a pass to bufferize arith.constant ops.
std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
+/// Adds patterns to emulate wide Arithmetic and Function ops over integer
+/// types using supported, narrower ones. This is done by splitting original
+/// power-of-two i2N integer types into two iN halves, where N is the widest
+/// integer bitwidth supported by `typeConverter`.
+/// NOTE: at present only the `func.func`, `func.call`, and `func.return`
+/// type-conversion patterns are added; no `arith` op patterns are added yet.
+void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
/// Add patterns to expand Arithmetic ops for LLVM lowering.
void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
index 752d71508795..3895562079d7 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -49,4 +49,24 @@ def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
}
+// Splits integers that are twice as wide as the target supports into two
+// halves of the widest supported width.
+def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> {
+  let summary = "Emulate 2*N-bit integer operations using N-bit operations";
+  let description = [{
+    Emulate integer operations that use integer types too wide for the target
+    with equivalent operations on supported, narrower integer types. This is
+    done by splitting original integer values into two halves.
+
+    This pass is intended to preserve semantics but not necessarily provide
+    the most efficient implementation.
+    TODO: Optimize op emulation.
+
+    Currently, only power-of-two integer bitwidths are supported; the pass
+    fails if `widest-int-supported` is not a power of two.
+  }];
+  let options = [
+    Option<"widestIntSupported", "widest-int-supported", "unsigned",
+           /*default=*/"32", "Widest integer type supported by the target">,
+  ];
+  let dependentDialects = ["vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
new file mode 100644
index 000000000000..814db23e4340
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
@@ -0,0 +1,34 @@
+//===- WideIntEmulationConverter.h - Type Converter for WIE -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts integer types that are too wide for the target into supported ones
+/// by splitting them into two halves, i.e., i2N --> 2 x iN, where N is the
+/// widest integer bitwidth supported by the target.
+/// Currently, only power-of-two integer types are handled, and only integers
+/// exactly twice as wide as the maximum supported by the target are converted.
+/// Wide integers are represented as vectors, e.g., i64 --> vector<2xi32>,
+/// where the first element is the low half of the original integer, and the
+/// second element is the high half. Vectors of wide integers gain a trailing
+/// dimension of size 2 holding the halves, e.g.,
+/// vector<4xi64> --> vector<4x2xi32>.
+class WideIntEmulationConverter : public TypeConverter {
+public:
+  /// Creates a converter for a target whose widest supported integer type is
+  /// `widestIntSupportedByTarget` bits wide. The width is expected to be a
+  /// power of two.
+  explicit WideIntEmulationConverter(unsigned widestIntSupportedByTarget);
+
+  /// Returns the widest integer bitwidth (N) supported by the target.
+  unsigned getMaxTargetIntBitWidth() const { return maxIntWidth; }
+
+private:
+  /// Widest integer bitwidth supported by the target; set by the constructor.
+  unsigned maxIntWidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
index f140715e603e..ba68d36de0e9 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArithmeticTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ EmulateWideInt.cpp
ExpandOps.cpp
UnsignedWhenEquivalent.cpp
@@ -15,9 +16,13 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
+ # Added for EmulateWideInt.cpp, which includes the func dialect ops and the
+ # func type-conversion helpers (FuncOps.h, FuncConversions.h).
+ MLIRFuncDialect
+ MLIRFuncTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRTransforms
+ # Added for EmulateWideInt.cpp, which also includes DialectConversion.h and
+ # the vector dialect (VectorOps.h).
+ MLIRTransformUtils
+ MLIRVectorDialect
)
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
new file mode 100644
index 000000000000..94e321ba1ad7
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -0,0 +1,120 @@
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHMETICEMULATEWIDEINT
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+
+namespace {
+/// Pass that emulates integer operations wider than the target supports by
+/// rewriting their types (and, via the populated patterns, the ops using them)
+/// in terms of half-width integer types.
+struct EmulateWideIntPass final
+    : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
+  using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    // The type converter only supports power-of-two target widths. Explain
+    // the failure instead of failing the pass silently.
+    unsigned widestInt = widestIntSupported;
+    if (!llvm::isPowerOf2_32(widestInt)) {
+      op->emitError() << "'widest-int-supported' must be a power of two, got "
+                      << widestInt;
+      signalPassFailure();
+      return;
+    }
+
+    MLIRContext *ctx = op->getContext();
+
+    arith::WideIntEmulationConverter typeConverter(widestInt);
+    ConversionTarget target(*ctx);
+    // Function-like ops become legal once none of their types need converting.
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+    arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+/// Builds the type converter.
+/// - Integers (and vectors of integers) no wider than
+///   `widestIntSupportedByTarget` are kept as-is.
+/// - Integers exactly twice as wide are split into a pair of halves.
+/// - Any wider integer fails to convert.
+/// - Function types are converted element-wise.
+/// - All other types (e.g., floats, `index`) pass through unchanged.
+arith::WideIntEmulationConverter::WideIntEmulationConverter(
+    unsigned widestIntSupportedByTarget)
+    : maxIntWidth(widestIntSupportedByTarget) {
+  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
+         "Only power-of-two integers are supported");
+
+  // Fallback: leave all other types (e.g., floats, `index`, memrefs) unchanged.
+  // Conversions are tried in reverse registration order, so this catch-all is
+  // registered first and only consulted when no more specific conversion
+  // applies. Without it, any function, call, or return mentioning a
+  // non-integer type is deemed illegal and the whole conversion fails.
+  addConversion([](Type ty) -> Optional<Type> { return ty; });
+
+  // Scalar case.
+  addConversion([this](IntegerType ty) -> Optional<Type> {
+    unsigned width = ty.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+
+    // i2N --> vector<2xiN>
+    if (width == 2 * maxIntWidth)
+      return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
+
+    // Too wide to emulate. Return a null type (hard failure) rather than
+    // `None`, which would let the catch-all above accept the type unchanged.
+    return Type();
+  });
+
+  // Vector case.
+  addConversion([this](VectorType ty) -> Optional<Type> {
+    auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+    if (!intTy)
+      return ty;
+
+    unsigned width = intTy.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+
+    // vector<...xi2N> --> vector<...x2xiN>
+    if (width == 2 * maxIntWidth) {
+      auto newShape = to_vector(ty.getShape());
+      newShape.push_back(2);
+      return VectorType::get(newShape,
+                             IntegerType::get(ty.getContext(), maxIntWidth));
+    }
+
+    // Too wide to emulate: hard failure, see the scalar case.
+    return Type();
+  });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> Optional<Type> {
+    // Convert inputs and results, e.g.:
+    // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
+    // Any unconvertible input/result makes the whole function type fail
+    // (hard failure, so the catch-all cannot keep it unchanged).
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return Type();
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return Type();
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
+
+/// Registers the `func.*` type-conversion patterns (function signatures, call
+/// sites, and returns) driven by `typeConverter`.
+void arith::populateWideIntEmulationPatterns(
+    WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // WideIntEmulationConverter derives from TypeConverter; bind to the base.
+  TypeConverter &converter = typeConverter;
+
+  // Type-conversion pattern for `func.func` signatures.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 converter);
+  // Type-conversion pattern for `func.call`.
+  populateCallOpTypeConversionPattern(patterns, converter);
+  // Type-conversion pattern for `func.return`.
+  populateReturnOpTypeConversionPattern(patterns, converter);
+}
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
new file mode 100644
index 000000000000..aafeb5b3a1cd
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+
+// All cases below run with i32 as the widest supported integer type, so i64
+// values are emulated with two i32 halves.
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_i32
+// CHECK-SAME: ([[ARG:%.+]]: i32) -> i32
+// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32
+// CHECK-NEXT: return [[X]] : i32
+func.func @addi_same_i32(%a : i32) -> i32 {
+ %x = arith.addi %a, %a : i32
+ return %x : i32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_vector_i32
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32>
+// CHECK-NEXT: return [[X]] : vector<2xi32>
+func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> {
+ %x = arith.addi %a, %a : vector<2xi32>
+ return %x : vector<2xi32>
+}
+
+// A scalar i64 is split into a vector<2xi32> of (low, high) halves.
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<2xi32>
+func.func @identity_scalar(%x : i64) -> i64 {
+ return %x : i64
+}
+
+// Each i64 vector element gains a trailing dimension of two i32 halves.
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<4x2xi32>
+func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> {
+ return %x : vector<4xi64>
+}
+
+// For multi-dimensional vectors the halves dimension is appended last.
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x4x2xi32>) -> vector<3x4x2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<3x4x2xi32>
+func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> {
+ return %x : vector<3x4xi64>
+}
+
+// Call sites are rewritten to use the converted callee signature.
+// CHECK-LABEL: func @call
+// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: return [[RES]] : vector<4x2xi32>
+func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
+ %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
+ return %res : vector<4xi64>
+}
More information about the Mlir-commits
mailing list