[Mlir-commits] [mlir] 242d558 - [mlir][arith] Add test pass for wide integer emulation
Jakub Kuderski
llvmlistbot at llvm.org
Tue Sep 20 08:26:03 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-20T11:22:28-04:00
New Revision: 242d558658cd5a480b02883e2982d7246342e0d0
URL: https://github.com/llvm/llvm-project/commit/242d558658cd5a480b02883e2982d7246342e0d0
DIFF: https://github.com/llvm/llvm-project/commit/242d558658cd5a480b02883e2982d7246342e0d0.diff
LOG: [mlir][arith] Add test pass for wide integer emulation
The new test pass allows for running wide integer emulation conversion
within specified functions only.
I intend to use it in integration tests in a way that allows me to print both
original and emulated results in the same format, or even compare both results
at runtime and print on mismatch only.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D134120
Added:
mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir
mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt
mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp
Modified:
mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
mlir/test/lib/Dialect/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir b/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir
new file mode 100644
index 0000000000000..bc6151e1d472f
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir
@@ -0,0 +1,39 @@
+// Check that the test version of the wide integer emulation pass applies
+// conversion to functions whose names start with a given prefix only, and that
+// the function signatures are preserved.
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="function-prefix=emulate_me_" | FileCheck %s
+
+// CHECK-LABEL: func.func @entry()
+// CHECK: {{%.+}} = call @emulate_me_please({{.+}}) : (i64) -> i64
+// CHECK-NEXT: {{%.+}} = call @foo({{.+}}) : (i64) -> i64
+func.func @entry() {
+ %cst0 = arith.constant 0 : i64
+ func.call @emulate_me_please(%cst0) : (i64) -> (i64)
+ func.call @foo(%cst0) : (i64) -> (i64)
+ return
+}
+
+// CHECK-LABEL: func.func @emulate_me_please
+// CHECK-SAME: ([[ARG:%.+]]: i64) -> i64 {
+// CHECK-NEXT: [[BCAST0:%.+]] = llvm.bitcast [[ARG]] : i64 to vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[BCAST0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[BCAST0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32>
+// CHECK-NEXT: {{%.+}}, {{%.+}} = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1
+// CHECK: [[RES:%.+]] = llvm.bitcast {{%.+}} : vector<2xi32> to i64
+// CHECK-NEXT: return [[RES]] : i64
+func.func @emulate_me_please(%x : i64) -> i64 {
+ %r = arith.addi %x, %x : i64
+ return %r : i64
+}
+
+// CHECK-LABEL: func.func @foo
+// CHECK-SAME: ([[ARG:%.+]]: i64) -> i64 {
+// CHECK-NEXT: [[RES:%.+]] = arith.addi [[ARG]], [[ARG]] : i64
+// CHECK-NEXT: return [[RES]] : i64
+func.func @foo(%x : i64) -> i64 {
+ %r = arith.addi %x, %x : i64
+ return %r : i64
+}
diff --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
index 8cc5ceba06522..22ef5d4087459 100644
--- a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
@@ -1,7 +1,7 @@
// Check that the wide integer constant emulation produces the same result as wide
// constants and that printing works. Emulate i16 ops with i8 ops.
-// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
@@ -9,6 +9,16 @@
// RUN: FileCheck %s --match-full-lines --check-prefix=EMULATED
func.func @entry() {
+ %cst0 = arith.constant 0 : i16
+ func.call @emulate_constant(%cst0) : (i16) -> ()
+ func.call @foo(%cst0) : (i16) -> ()
+ return
+}
+
+func.func @emulate_constant(%first : i16) {
+ // EMULATED: ( 0, 0 )
+ vector.print %first : i16
+
%cst0 = arith.constant 0 : i16
%cst1 = arith.constant 1 : i16
%cst_1 = arith.constant -1 : i16
@@ -20,7 +30,7 @@ func.func @entry() {
%cst_i16_max = arith.constant 32767 : i16
%cst_i16_min = arith.constant -32768 : i16
- // EMULATED: ( 0, 0 )
+ // EMULATED-NEXT: ( 0, 0 )
vector.print %cst0 : i16
// EMULATED-NEXT: ( 1, 0 )
vector.print %cst1 : i16
@@ -39,6 +49,17 @@ func.func @entry() {
vector.print %cst_i16_max : i16
// EMULATED-NEXT: ( 0, -128 )
vector.print %cst_i16_min : i16
+ return
+}
+func.func @foo(%first: i16) {
+ // These should not be emulated because the function name does not start with
+ // 'emulate_'.
+
+ // EMULATED-NEXT: 0
+ vector.print %first : i16
+ // EMULATED-NEXT: 1
+ %cst1 = arith.constant 1 : i16
+ vector.print %cst1 : i16
return
}
diff --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
index 7a56ed929d74b..976e28fe72ca8 100644
--- a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
@@ -5,17 +5,23 @@
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s --match-full-lines --check-prefix=WIDE
+// RUN: FileCheck %s --match-full-lines
-// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s --match-full-lines --check-prefix=EMULATED
+// RUN: FileCheck %s --match-full-lines
-func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
+// Ops in this function *only* will be emulated using i8 types.
+func.func @emulate_muli(%lhs : i16, %rhs : i16) -> (i16) {
%res = arith.muli %lhs, %rhs : i16
+ return %res : i16
+}
+
+func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
+ %res = func.call @emulate_muli(%lhs, %rhs) : (i16, i16) -> (i16)
vector.print %res : i16
return
}
@@ -34,63 +40,45 @@ func.func @entry() {
%cst_i16_max = arith.constant 32767 : i16
%cst_i16_min = arith.constant -32768 : i16
- // WIDE: 0
- // EMULATED: ( 0, 0 )
+ // CHECK: 0
func.call @check_muli(%cst0, %cst0) : (i16, i16) -> ()
- // WIDE-NEXT: 0
- // EMULATED-NEXT: ( 0, 0 )
+ // CHECK-NEXT: 0
func.call @check_muli(%cst0, %cst1) : (i16, i16) -> ()
- // WIDE-NEXT: 1
- // EMULATED-NEXT: ( 1, 0 )
+ // CHECK-NEXT: 1
func.call @check_muli(%cst1, %cst1) : (i16, i16) -> ()
- // WIDE-NEXT: -1
- // EMULATED-NEXT: ( -1, -1 )
+ // CHECK-NEXT: -1
func.call @check_muli(%cst1, %cst_1) : (i16, i16) -> ()
- // WIDE-NEXT: 1
- // EMULATED-NEXT: ( 1, 0 )
+ // CHECK-NEXT: 1
func.call @check_muli(%cst_1, %cst_1) : (i16, i16) -> ()
- // WIDE-NEXT: -3
- // EMULATED-NEXT: ( -3, -1 )
+ // CHECK-NEXT: -3
func.call @check_muli(%cst1, %cst_3) : (i16, i16) -> ()
- // WIDE-NEXT: 169
- // EMULATED-NEXT: ( -87, 0 )
+ // CHECK-NEXT: 169
func.call @check_muli(%cst13, %cst13) : (i16, i16) -> ()
- // WIDE-NEXT: 481
- // EMULATED-NEXT: ( -31, 1 )
+ // CHECK-NEXT: 481
func.call @check_muli(%cst13, %cst37) : (i16, i16) -> ()
- // WIDE-NEXT: 1554
- // EMULATED-NEXT: ( 18, 6 )
+ // CHECK-NEXT: 1554
func.call @check_muli(%cst37, %cst42) : (i16, i16) -> ()
- // WIDE-NEXT: -256
- // EMULATED-NEXT: ( 0, -1 )
+ // CHECK-NEXT: -256
func.call @check_muli(%cst_1, %cst256) : (i16, i16) -> ()
- // WIDE-NEXT: 3328
- // EMULATED-NEXT: ( 0, 13 )
+ // CHECK-NEXT: 3328
func.call @check_muli(%cst256, %cst13) : (i16, i16) -> ()
- // WIDE-NEXT: 9472
- // EMULATED-NEXT: ( 0, 37 )
+ // CHECK-NEXT: 9472
func.call @check_muli(%cst256, %cst37) : (i16, i16) -> ()
- // WIDE-NEXT: -768
- // EMULATED-NEXT: ( 0, -3 )
+ // CHECK-NEXT: -768
func.call @check_muli(%cst256, %cst_3) : (i16, i16) -> ()
- // WIDE-NEXT: 32755
- // EMULATED-NEXT: ( -13, 127 )
+ // CHECK-NEXT: 32755
func.call @check_muli(%cst13, %cst_i16_max) : (i16, i16) -> ()
- // WIDE-NEXT: -32768
- // EMULATED-NEXT: ( 0, -128 )
+ // CHECK-NEXT: -32768
func.call @check_muli(%cst_i16_min, %cst37) : (i16, i16) -> ()
- // WIDE-NEXT: 1
- // EMULATED-NEXT: ( 1, 0 )
+ // CHECK-NEXT: 1
func.call @check_muli(%cst_i16_max, %cst_i16_max) : (i16, i16) -> ()
- // WIDE-NEXT: -32768
- // EMULATED-NEXT: ( 0, -128 )
+ // CHECK-NEXT: -32768
func.call @check_muli(%cst_i16_min, %cst13) : (i16, i16) -> ()
- // WIDE-NEXT: 0
- // EMULATED-NEXT: ( 0, 0 )
+ // CHECK-NEXT: 0
func.call @check_muli(%cst_i16_min, %cst_i16_min) : (i16, i16) -> ()
return
diff --git a/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt
new file mode 100644
index 0000000000000..17d288e7a2a5f
--- /dev/null
+++ b/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRArithmeticTestPasses
+ TestEmulateWideInt.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRArithmeticDialect
+ MLIRArithmeticTransforms
+ MLIRFuncDialect
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRVectorDialect
+)
diff --git a/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp
new file mode 100644
index 0000000000000..7cf76ad01502f
--- /dev/null
+++ b/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp
@@ -0,0 +1,95 @@
+//===- TestEmulateWideInt.cpp - Test Wide Int Emulation --------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for integration testing of wide integer
+// emulation patterns. Applies conversion patterns only to functions whose
+// names start with a specified prefix.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct TestEmulateWideIntPass
+ : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)
+
+ TestEmulateWideIntPass() = default;
+ TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
+ : PassWrapper(pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithmeticDialect, func::FuncDialect,
+ LLVM::LLVMDialect, vector::VectorDialect>();
+ }
+ StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }
+ StringRef getDescription() const final {
+ return "Function pass to test Wide Integer Emulation";
+ }
+
+ void runOnOperation() override {
+ if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
+ signalPassFailure();
+ return;
+ }
+
+ func::FuncOp op = getOperation();
+ if (!op.getSymName().startswith(testFunctionPrefix))
+ return;
+
+ MLIRContext *ctx = op.getContext();
+ arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+
+ // Use `llvm.bitcast` as the bridge so that we can preserve the
+ // function argument and return types of the processed function.
+ // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
+ // casts (and vice versa) and using it instead of `llvm.bitcast`.
+ auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Optional<Value> {
+ auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
+ return cast->getResult(0);
+ };
+ typeConverter.addSourceMaterialization(addBitcast);
+ typeConverter.addTargetMaterialization(addBitcast);
+
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
+ vector::VectorDialect>(
+ [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+ RewritePatternSet patterns(ctx);
+ arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ signalPassFailure();
+ }
+
+ Option<std::string> testFunctionPrefix{
+ *this, "function-prefix",
+ llvm::cl::desc("Prefix of functions to run the emulation pass on"),
+ llvm::cl::init("emulate_")};
+ Option<unsigned> widestIntSupported{
+ *this, "widest-int-supported",
+ llvm::cl::desc("Maximum integer bit width supported by the target"),
+ llvm::cl::init(32)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestArithmeticEmulateWideIntPass() {
+ PassRegistration<TestEmulateWideIntPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 46b38dcfbc736..002e484f8195b 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(Affine)
+add_subdirectory(Arithmetic)
add_subdirectory(DLTI)
add_subdirectory(Func)
add_subdirectory(GPU)
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index e1a90956e122b..3b27cfda044d1 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -14,6 +14,7 @@ if(MLIR_INCLUDE_TESTS)
set(test_libs
MLIRTestFuncToLLVM
MLIRAffineTransformsTestPasses
+ MLIRArithmeticTestPasses
MLIRDLTITestPasses
MLIRFuncTestPasses
MLIRGPUTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 05bb9e425550d..58e65988d1df6 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -64,6 +64,7 @@ void registerMemRefBoundCheck();
void registerPatternsTestPass();
void registerSimpleParametricTilingPass();
void registerTestAffineLoopParametricTilingPass();
+void registerTestArithmeticEmulateWideIntPass();
void registerTestAliasAnalysisPass();
void registerTestBuiltinAttributeInterfaces();
void registerTestCallGraphPass();
@@ -161,6 +162,7 @@ void registerTestPasses() {
mlir::test::registerSimpleParametricTilingPass();
mlir::test::registerTestAffineLoopParametricTilingPass();
mlir::test::registerTestAliasAnalysisPass();
+ mlir::test::registerTestArithmeticEmulateWideIntPass();
mlir::test::registerTestBuiltinAttributeInterfaces();
mlir::test::registerTestCallGraphPass();
mlir::test::registerTestConstantFold();
More information about the Mlir-commits
mailing list