[Mlir-commits] [mlir] fa8eb27 - [mlir][arith] Add wide integer emulation pass
Jakub Kuderski
llvmlistbot at llvm.org
Thu Sep 8 10:53:16 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-08T13:51:01-04:00
New Revision: fa8eb2708814a406261588fafe922047095b0db0
URL: https://github.com/llvm/llvm-project/commit/fa8eb2708814a406261588fafe922047095b0db0
DIFF: https://github.com/llvm/llvm-project/commit/fa8eb2708814a406261588fafe922047095b0db0.diff
LOG: [mlir][arith] Add wide integer emulation pass
This is the first patch in a series adding wide integer emulation:
* Set up the initial pass structure
* Add a custom type converter
* Handle func ops
The initial implementation supports power-of-two integer types only. We
emulate wide integer operations by splitting original i2N integer types
into two iN halves.
My immediate use case is to emulate i64 operations using i32 ones
on mobile GPUs that do not support i64.
Reviewed By: antiagainst, Mogball
Differential Revision: https://reviews.llvm.org/D133135
Added:
mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
Modified:
mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
index 922d653decbbe..5ee3fb0d72367 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -15,16 +15,25 @@ namespace mlir {
namespace arith {
#define GEN_PASS_DECL_ARITHMETICBUFFERIZE
+#define GEN_PASS_DECL_ARITHMETICEMULATEWIDEINT
#define GEN_PASS_DECL_ARITHMETICEXPANDOPS
#define GEN_PASS_DECL_ARITHMETICUNSIGNEDWHENEQUIVALENT
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+class WideIntEmulationConverter;
+
/// Create a pass to bufferize Arithmetic ops.
std::unique_ptr<Pass> createArithmeticBufferizePass();
/// Create a pass to bufferize arith.constant ops.
std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
+/// Adds patterns that emulate operations on integer types too wide for the
+/// target using supported narrower types, as decided by `typeConverter`.
+/// Original power-of-two i2N integer types are split into two iN halves.
+/// NOTE: at present only `func.func`, `func.call`, and `func.return` type
+/// conversion patterns are added.
+void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
/// Add patterns to expand Arithmetic ops for LLVM lowering.
void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
index 752d715087959..3895562079d70 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -49,4 +49,24 @@ def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
}
+def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> {
+  let summary = "Emulate 2*N-bit integer operations using N-bit operations";
+  let description = [{
+    Emulate integer operations that use too wide integer types with equivalent
+    operations on supported narrow integer types. This is done by splitting
+    original integer values into two halves.
+
+    This pass is intended to preserve semantics but not necessarily provide the
+    most efficient implementation.
+    TODO: Optimize op emulation.
+
+    Currently, only power-of-two integer bitwidths are supported, and the
+    `widest-int-supported` option must be a power of two.
+  }];
+  let options = [
+    Option<"widestIntSupported", "widest-int-supported", "unsigned",
+           /*default=*/"32", "Widest integer type supported by the target">,
+  ];
+  let dependentDialects = ["vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
new file mode 100644
index 0000000000000..814db23e43402
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
@@ -0,0 +1,34 @@
+//===- WideIntEmulationConverter.h - Type Converter for WIE -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts integer types that are too wide for the target by splitting them in
+/// two halves and thus turning them into supported ones, i.e., i2N -->
+/// vector<2xiN>, where N is the widest integer bitwidth supported by the
+/// target. Vectors of wide integers gain a trailing dimension of size 2, i.e.,
+/// vector<...xi2N> --> vector<...x2xiN>.
+/// Currently, we only handle power-of-two integer types and support conversions
+/// of integers twice as wide as the maximum supported by the target. Wide
+/// integers are represented as vectors, e.g., i64 --> vector<2xi32>, where the
+/// first element is the low half of the original integer, and the second
+/// element the high half.
+class WideIntEmulationConverter : public TypeConverter {
+public:
+  /// Creates a converter for a target whose widest supported integer type is
+  /// `widestIntSupportedByTarget` bits wide. The width must be a power of two.
+  explicit WideIntEmulationConverter(unsigned widestIntSupportedByTarget);
+
+  /// Returns the widest integer bitwidth supported by the target, which is also
+  /// the bitwidth of each half of an emulated wide integer.
+  unsigned getMaxTargetIntBitWidth() const { return maxIntWidth; }
+
+private:
+  /// Widest integer bitwidth supported by the target (N in i2N --> 2 x iN).
+  unsigned maxIntWidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
index f140715e603ee..0b56659287738 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArithmeticTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ EmulateWideInt.cpp
ExpandOps.cpp
UnsignedWhenEquivalent.cpp
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
new file mode 100644
index 0000000000000..94e321ba1ad76
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -0,0 +1,120 @@
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHMETICEMULATEWIDEINT
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+
+namespace {
+/// Emulates integer types twice as wide as `widest-int-supported` by splitting
+/// them into two supported halves in function signatures, calls, and returns.
+struct EmulateWideIntPass final
+    : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
+  using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    // WideIntEmulationConverter asserts that the width is a power of two, so
+    // reject invalid option values up front and explain why the pass failed
+    // instead of failing silently.
+    unsigned widestWidth = widestIntSupported;
+    if (!llvm::isPowerOf2_32(widestWidth)) {
+      op->emitError() << "'widest-int-supported' must be a power of two, got "
+                      << widestWidth;
+      signalPassFailure();
+      return;
+    }
+
+    MLIRContext *ctx = op->getContext();
+
+    arith::WideIntEmulationConverter typeConverter(widestWidth);
+    ConversionTarget target(*ctx);
+    // A function is legal once its signature needs no further type conversion.
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [&typeConverter](Operation *candidate) {
+          return typeConverter.isLegal(
+              cast<func::FuncOp>(candidate).getFunctionType());
+        });
+    // Calls and returns are legal once all their operand and result types are.
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        [&typeConverter](Operation *candidate) {
+          return typeConverter.isLegal(candidate);
+        });
+
+    RewritePatternSet patterns(ctx);
+    arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+arith::WideIntEmulationConverter::WideIntEmulationConverter(
+    unsigned widestIntSupportedByTarget)
+    : maxIntWidth(widestIntSupportedByTarget) {
+  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
+         "Only power-of-two integers are supported");
+
+  // Scalar case: integers no wider than the target limit are kept as-is,
+  // i2N becomes vector<2xiN>, and anything wider has no conversion (None).
+  addConversion([this](IntegerType ty) -> Optional<Type> {
+    unsigned width = ty.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+
+    // i2N --> vector<2xiN>
+    if (width == 2 * maxIntWidth)
+      return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
+
+    return None;
+  });
+
+  // Vector case: vectors of non-integers or of narrow integers are kept as-is;
+  // vectors of i2N gain a trailing dimension of size 2 holding the halves.
+  addConversion([this](VectorType ty) -> Optional<Type> {
+    auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+    if (!intTy)
+      return ty;
+
+    unsigned width = intTy.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+
+    // vector<...xi2N> --> vector<...x2xiN>
+    if (width == 2 * maxIntWidth) {
+      // `VectorType::get` below builds a fixed-size vector, so converting a
+      // scalable vector this way would silently drop its scalability. Refuse
+      // the conversion rather than produce an incorrect type.
+      if (ty.isScalable())
+        return None;
+
+      auto newShape = to_vector(ty.getShape());
+      newShape.push_back(2);
+      return VectorType::get(newShape,
+                             IntegerType::get(ty.getContext(), maxIntWidth));
+    }
+
+    return None;
+  });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> Optional<Type> {
+    // Convert inputs and results, e.g.:
+    // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return None;
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return None;
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
+
+void arith::populateWideIntEmulationPatterns(
+    WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // `func.func`: rewrite function signatures using the converted types.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+      patterns, typeConverter);
+
+  // `func.call` / `func.return`: keep call sites and terminators consistent
+  // with the converted signatures.
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
new file mode 100644
index 0000000000000..aafeb5b3a1cdd
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+
+// Expect no conversions: i32 fits within the supported 32-bit width, so the
+// signature and the arith.addi are left unchanged.
+// CHECK-LABEL: func @addi_same_i32
+// CHECK-SAME: ([[ARG:%.+]]: i32) -> i32
+// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32
+// CHECK-NEXT: return [[X]] : i32
+func.func @addi_same_i32(%a : i32) -> i32 {
+ %x = arith.addi %a, %a : i32
+ return %x : i32
+}
+
+// Expect no conversions: vectors whose element type is i32 are already
+// supported, so they are left unchanged.
+// CHECK-LABEL: func @addi_same_vector_i32
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32>
+// CHECK-NEXT: return [[X]] : vector<2xi32>
+func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> {
+ %x = arith.addi %a, %a : vector<2xi32>
+ return %x : vector<2xi32>
+}
+
+// A scalar i64 argument/result is converted to vector<2xi32> (two i32 halves).
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<2xi32>
+func.func @identity_scalar(%x : i64) -> i64 {
+ return %x : i64
+}
+
+// A 1-D vector of i64 gains a trailing dimension of size 2:
+// vector<4xi64> --> vector<4x2xi32>.
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<4x2xi32>
+func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> {
+ return %x : vector<4xi64>
+}
+
+// Multi-dimensional vectors also gain a trailing dimension of size 2:
+// vector<3x4xi64> --> vector<3x4x2xi32>.
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x4x2xi32>) -> vector<3x4x2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<3x4x2xi32>
+func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> {
+ return %x : vector<3x4xi64>
+}
+
+// Call sites are rewritten to use the converted callee signature, and the
+// call result feeds the converted return.
+// CHECK-LABEL: func @call
+// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: return [[RES]] : vector<4x2xi32>
+func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
+ %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
+ return %res : vector<4xi64>
+}
More information about the Mlir-commits
mailing list