[Mlir-commits] [mlir] 242d558 - [mlir][arith] Add test pass for wide integer emulation

Jakub Kuderski llvmlistbot at llvm.org
Tue Sep 20 08:26:03 PDT 2022


Author: Jakub Kuderski
Date: 2022-09-20T11:22:28-04:00
New Revision: 242d558658cd5a480b02883e2982d7246342e0d0

URL: https://github.com/llvm/llvm-project/commit/242d558658cd5a480b02883e2982d7246342e0d0
DIFF: https://github.com/llvm/llvm-project/commit/242d558658cd5a480b02883e2982d7246342e0d0.diff

LOG: [mlir][arith] Add test pass for wide integer emulation

The new test pass allows for running wide integer emulation conversion
within specified functions only.

I intend to use it in integration tests in a way that allows me to print both
original and emulated results in the same format, or even compare both results
at runtime and print on mismatch only.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134120

Added: 
    mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir
    mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt
    mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp

Modified: 
    mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
    mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir b/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir
new file mode 100644
index 0000000000000..bc6151e1d472f
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir
@@ -0,0 +1,39 @@
+// Check that the test version of the wide integer emulation pass applies
+// conversion only to functions whose names start with a given prefix, and that
+// the function signatures are preserved.
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="function-prefix=emulate_me_" | FileCheck %s
+
+// CHECK-LABEL: func.func @entry()
+// CHECK:         {{%.+}} = call @emulate_me_please({{.+}}) : (i64) -> i64
+// CHECK-NEXT:    {{%.+}} = call @foo({{.+}}) : (i64) -> i64
+func.func @entry() {
+  %cst0 = arith.constant 0 : i64
+  func.call @emulate_me_please(%cst0) : (i64) -> (i64)
+  func.call @foo(%cst0) : (i64) -> (i64)
+  return
+}
+
+// CHECK-LABEL: func.func @emulate_me_please
+// CHECK-SAME:    ([[ARG:%.+]]: i64) -> i64 {
+// CHECK-NEXT:    [[BCAST0:%.+]]   = llvm.bitcast [[ARG]] : i64 to vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]     = vector.extract [[BCAST0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]    = vector.extract [[BCAST0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]     = vector.extract [[BCAST0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]    = vector.extract [[BCAST0]][1] : vector<2xi32>
+// CHECK-NEXT:    {{%.+}}, {{%.+}} = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1
+// CHECK:         [[RES:%.+]]      = llvm.bitcast {{%.+}} : vector<2xi32> to i64
+// CHECK-NEXT:    return [[RES]] : i64
+func.func @emulate_me_please(%x : i64) -> i64 {
+  %r = arith.addi %x, %x : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: func.func @foo
+// CHECK-SAME:    ([[ARG:%.+]]: i64) -> i64 {
+// CHECK-NEXT:    [[RES:%.+]] = arith.addi [[ARG]], [[ARG]] : i64
+// CHECK-NEXT:    return [[RES]] : i64
+func.func @foo(%x : i64) -> i64 {
+  %r = arith.addi %x, %x : i64
+  return %r : i64
+}

diff  --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
index 8cc5ceba06522..22ef5d4087459 100644
--- a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir
@@ -1,7 +1,7 @@
 // Check that the wide integer constant emulation produces the same result as wide
 // constants and that printing works. Emulate i16 ops with i8 ops.
 
-// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
 // RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
 // RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
 // RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
@@ -9,6 +9,16 @@
 // RUN:   FileCheck %s --match-full-lines --check-prefix=EMULATED
 
 func.func @entry() {
+  %cst0 = arith.constant 0 : i16
+  func.call @emulate_constant(%cst0) : (i16) -> ()
+  func.call @foo(%cst0) : (i16) -> ()
+  return
+}
+
+func.func @emulate_constant(%first : i16) {
+  // EMULATED: ( 0, 0 )
+  vector.print %first : i16
+
   %cst0 = arith.constant 0 : i16
   %cst1 = arith.constant 1 : i16
   %cst_1 = arith.constant -1 : i16
@@ -20,7 +30,7 @@ func.func @entry() {
   %cst_i16_max = arith.constant 32767 : i16
   %cst_i16_min = arith.constant -32768 : i16
 
-  // EMULATED: ( 0, 0 )
+  // EMULATED-NEXT: ( 0, 0 )
   vector.print %cst0 : i16
   // EMULATED-NEXT: ( 1, 0 )
   vector.print %cst1 : i16
@@ -39,6 +49,17 @@ func.func @entry() {
   vector.print %cst_i16_max : i16
   // EMULATED-NEXT: ( 0, -128 )
   vector.print %cst_i16_min : i16
+  return
+}
 
+func.func @foo(%first: i16) {
+  // These should not be emulated because the function name does not start with
+  // 'emulate_', the default function prefix.
+
+  // EMULATED-NEXT: 0
+  vector.print %first : i16
+  // EMULATED-NEXT: 1
+  %cst1 = arith.constant 1 : i16
+  vector.print %cst1 : i16
   return
 }

diff  --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
index 7a56ed929d74b..976e28fe72ca8 100644
--- a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir
@@ -5,17 +5,23 @@
 // RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
 // RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN:   FileCheck %s --match-full-lines --check-prefix=WIDE
+// RUN:   FileCheck %s --match-full-lines
 
-// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
 // RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
 // RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
 // RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN:   FileCheck %s --match-full-lines --check-prefix=EMULATED
+// RUN:   FileCheck %s --match-full-lines
 
-func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
+// Ops in this function *only* will be emulated using i8 types.
+func.func @emulate_muli(%lhs : i16, %rhs : i16) -> (i16) {
   %res = arith.muli %lhs, %rhs : i16
+  return %res : i16
+}
+
+func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
+  %res = func.call @emulate_muli(%lhs, %rhs) : (i16, i16) -> (i16)
   vector.print %res : i16
   return
 }
@@ -34,63 +40,45 @@ func.func @entry() {
   %cst_i16_max = arith.constant 32767 : i16
   %cst_i16_min = arith.constant -32768 : i16
 
-  // WIDE: 0
-  // EMULATED: ( 0, 0 )
+  // CHECK: 0
   func.call @check_muli(%cst0, %cst0) : (i16, i16) -> ()
-  // WIDE-NEXT: 0
-  // EMULATED-NEXT: ( 0, 0 )
+  // CHECK-NEXT: 0
   func.call @check_muli(%cst0, %cst1) : (i16, i16) -> ()
-  // WIDE-NEXT: 1
-  // EMULATED-NEXT: ( 1, 0 )
+  // CHECK-NEXT: 1
   func.call @check_muli(%cst1, %cst1) : (i16, i16) -> ()
-  // WIDE-NEXT: -1
-  // EMULATED-NEXT: ( -1, -1 )
+  // CHECK-NEXT: -1
   func.call @check_muli(%cst1, %cst_1) : (i16, i16) -> ()
-  // WIDE-NEXT: 1
-  // EMULATED-NEXT: ( 1, 0 )
+  // CHECK-NEXT: 1
   func.call @check_muli(%cst_1, %cst_1) : (i16, i16) -> ()
-  // WIDE-NEXT: -3
-  // EMULATED-NEXT: ( -3, -1 )
+  // CHECK-NEXT: -3
   func.call @check_muli(%cst1, %cst_3) : (i16, i16) -> ()
 
-  // WIDE-NEXT: 169
-  // EMULATED-NEXT: ( -87, 0 )
+  // CHECK-NEXT: 169
   func.call @check_muli(%cst13, %cst13) : (i16, i16) -> ()
-  // WIDE-NEXT: 481
-  // EMULATED-NEXT: ( -31, 1 )
+  // CHECK-NEXT: 481
   func.call @check_muli(%cst13, %cst37) : (i16, i16) -> ()
-  // WIDE-NEXT: 1554
-  // EMULATED-NEXT: ( 18, 6 )
+  // CHECK-NEXT: 1554
   func.call @check_muli(%cst37, %cst42) : (i16, i16) -> ()
 
-  // WIDE-NEXT: -256
-  // EMULATED-NEXT: ( 0, -1 )
+  // CHECK-NEXT: -256
   func.call @check_muli(%cst_1, %cst256) : (i16, i16) -> ()
-  // WIDE-NEXT: 3328
-  // EMULATED-NEXT: ( 0, 13 )
+  // CHECK-NEXT: 3328
   func.call @check_muli(%cst256, %cst13) : (i16, i16) -> ()
-  // WIDE-NEXT: 9472
-  // EMULATED-NEXT: ( 0, 37 )
+  // CHECK-NEXT: 9472
   func.call @check_muli(%cst256, %cst37) : (i16, i16) -> ()
-  // WIDE-NEXT: -768
-  // EMULATED-NEXT: ( 0, -3 )
+  // CHECK-NEXT: -768
   func.call @check_muli(%cst256, %cst_3) : (i16, i16) -> ()
 
-  // WIDE-NEXT: 32755
-  // EMULATED-NEXT: ( -13, 127 )
+  // CHECK-NEXT: 32755
   func.call @check_muli(%cst13, %cst_i16_max) : (i16, i16) -> ()
-  // WIDE-NEXT: -32768
-  // EMULATED-NEXT: ( 0, -128 )
+  // CHECK-NEXT: -32768
   func.call @check_muli(%cst_i16_min, %cst37) : (i16, i16) -> ()
 
-  // WIDE-NEXT: 1
-  // EMULATED-NEXT: ( 1, 0 )
+  // CHECK-NEXT: 1
   func.call @check_muli(%cst_i16_max, %cst_i16_max) : (i16, i16) -> ()
-  // WIDE-NEXT: -32768
-  // EMULATED-NEXT: ( 0, -128 )
+  // CHECK-NEXT: -32768
   func.call @check_muli(%cst_i16_min, %cst13) : (i16, i16) -> ()
-  // WIDE-NEXT: 0
-  // EMULATED-NEXT: ( 0, 0 )
+  // CHECK-NEXT: 0
   func.call @check_muli(%cst_i16_min, %cst_i16_min) : (i16, i16) -> ()
 
   return

diff  --git a/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt
new file mode 100644
index 0000000000000..17d288e7a2a5f
--- /dev/null
+++ b/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRArithmeticTestPasses
+  TestEmulateWideInt.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
+  MLIRArithmeticTransforms
+  MLIRFuncDialect
+  MLIRLLVMDialect
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+)

diff  --git a/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp
new file mode 100644
index 0000000000000..7cf76ad01502f
--- /dev/null
+++ b/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp
@@ -0,0 +1,112 @@
+//===- TestEmulateWideInt.cpp - Test Wide Int Emulation  --------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for integration testing of wide integer
+// emulation patterns. Applies conversion patterns only to functions whose
+// names start with a specified prefix.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that applies the wide integer emulation conversion only to
+/// functions whose symbol names start with the `function-prefix` option.
+/// The signatures of converted functions are preserved by bridging between
+/// the original and the emulated types with `llvm.bitcast` ops.
+struct TestEmulateWideIntPass
+    : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)
+
+  TestEmulateWideIntPass() = default;
+  // `Option` members cannot be copied, so only the base is copied here; the
+  // options are re-initialized from their in-class defaults.
+  TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithmeticDialect, func::FuncDialect,
+                    LLVM::LLVMDialect, vector::VectorDialect>();
+  }
+  StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }
+  StringRef getDescription() const final {
+    return "Function pass to test Wide Integer Emulation";
+  }
+
+  void runOnOperation() override {
+    func::FuncOp op = getOperation();
+
+    // Reject unsupported target widths with a diagnostic instead of failing
+    // silently.
+    if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
+      op.emitError("widest-int-supported must be a power of two >= 2");
+      signalPassFailure();
+      return;
+    }
+
+    // Functions outside of the requested prefix are left untouched.
+    if (!op.getSymName().startswith(testFunctionPrefix))
+      return;
+
+    MLIRContext *ctx = op.getContext();
+    arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+
+    // Use `llvm.bitcast` as the bridge so that we can preserve the
+    // function argument and return types of the processed function.
+    // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
+    // casts (and vice versa) and using it instead of `llvm.bitcast`.
+    auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                         Location loc) -> Optional<Value> {
+      // `llvm.bitcast` has a single operand; decline other input counts
+      // instead of building a malformed op.
+      if (inputs.size() != 1)
+        return llvm::None;
+      auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
+      return cast->getResult(0);
+    };
+    typeConverter.addSourceMaterialization(addBitcast);
+    typeConverter.addTargetMaterialization(addBitcast);
+
+    // Ops from the arith and vector dialects are legal only when all of their
+    // types are legal under the emulation type converter.
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
+                                      vector::VectorDialect>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+    arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+
+  Option<std::string> testFunctionPrefix{
+      *this, "function-prefix",
+      llvm::cl::desc("Prefix of functions to run the emulation pass on"),
+      llvm::cl::init("emulate_")};
+  Option<unsigned> widestIntSupported{
+      *this, "widest-int-supported",
+      llvm::cl::desc("Maximum integer bit width supported by the target"),
+      llvm::cl::init(32)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestArithmeticEmulateWideIntPass() {
+  PassRegistration<TestEmulateWideIntPass>();
+}
+} // namespace mlir::test

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 46b38dcfbc736..002e484f8195b 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Affine)
+add_subdirectory(Arithmetic)
 add_subdirectory(DLTI)
 add_subdirectory(Func)
 add_subdirectory(GPU)

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index e1a90956e122b..3b27cfda044d1 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -14,6 +14,7 @@ if(MLIR_INCLUDE_TESTS)
   set(test_libs
     MLIRTestFuncToLLVM
     MLIRAffineTransformsTestPasses
+    MLIRArithmeticTestPasses
     MLIRDLTITestPasses
     MLIRFuncTestPasses
     MLIRGPUTestPasses

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 05bb9e425550d..58e65988d1df6 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -64,6 +64,7 @@ void registerMemRefBoundCheck();
 void registerPatternsTestPass();
 void registerSimpleParametricTilingPass();
 void registerTestAffineLoopParametricTilingPass();
 void registerTestAliasAnalysisPass();
+void registerTestArithmeticEmulateWideIntPass();
 void registerTestBuiltinAttributeInterfaces();
 void registerTestCallGraphPass();
@@ -161,6 +162,7 @@ void registerTestPasses() {
   mlir::test::registerSimpleParametricTilingPass();
   mlir::test::registerTestAffineLoopParametricTilingPass();
   mlir::test::registerTestAliasAnalysisPass();
+  mlir::test::registerTestArithmeticEmulateWideIntPass();
   mlir::test::registerTestBuiltinAttributeInterfaces();
   mlir::test::registerTestCallGraphPass();
   mlir::test::registerTestConstantFold();


        


More information about the Mlir-commits mailing list