[Mlir-commits] [mlir] b8880f5 - [mlir] [VectorOps] generalize printing support for integers

Aart Bik llvmlistbot at llvm.org
Fri Sep 25 04:52:42 PDT 2020


Author: Aart Bik
Date: 2020-09-25T04:52:21-07:00
New Revision: b8880f5f97bf1628b2c9606e96abcd612dc7d747

URL: https://github.com/llvm/llvm-project/commit/b8880f5f97bf1628b2c9606e96abcd612dc7d747
DIFF: https://github.com/llvm/llvm-project/commit/b8880f5f97bf1628b2c9606e96abcd612dc7d747.diff

LOG: [mlir] [VectorOps] generalize printing support for integers

This generalizes printing beyond just i1, i32, and i64, and also accounts
for signed and unsigned interpretation in the output.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D88290

Added: 
    mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/ExecutionEngine/CRunnerUtils.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir
new file mode 100644
index 000000000000..946f02b0e3b9
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+//
+// Test various signless, signed, unsigned integer types.
+//
+func @entry() {
+  %0 = std.constant dense<[true, false, -1, 0, 1]> : vector<5xi1>
+  vector.print %0 : vector<5xi1>
+  // CHECK: ( 1, 0, 1, 0, 1 )
+
+  %1 = std.constant dense<[true, false, -1, 0]> : vector<4xsi1>
+  vector.print %1 : vector<4xsi1>
+  // CHECK: ( 1, 0, 1, 0 )
+
+  %2 = std.constant dense<[true, false, 0, 1]> : vector<4xui1>
+  vector.print %2 : vector<4xui1>
+  // CHECK: ( 1, 0, 0, 1 )
+
+  %3 = std.constant dense<[-128, -127, -1, 0, 1, 127, 128, 254, 255]> : vector<9xi8>
+  vector.print %3 : vector<9xi8>
+  // CHECK: ( -128, -127, -1, 0, 1, 127, -128, -2, -1 )
+
+  %4 = std.constant dense<[-128, -127, -1, 0, 1, 127]> : vector<6xsi8>
+  vector.print %4 : vector<6xsi8>
+  // CHECK: ( -128, -127, -1, 0, 1, 127 )
+
+  %5 = std.constant dense<[0, 1, 127, 128, 254, 255]> : vector<6xui8>
+  vector.print %5 : vector<6xui8>
+  // CHECK: ( 0, 1, 127, 128, 254, 255 )
+
+  %6 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767, 32768, 65534, 65535]> : vector<9xi16>
+  vector.print %6 : vector<9xi16>
+  // CHECK: ( -32768, -32767, -1, 0, 1, 32767, -32768, -2, -1 )
+
+  %7 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767]> : vector<6xsi16>
+  vector.print %7 : vector<6xsi16>
+  // CHECK: ( -32768, -32767, -1, 0, 1, 32767 )
+
+  %8 = std.constant dense<[0, 1, 32767, 32768, 65534, 65535]> : vector<6xui16>
+  vector.print %8 : vector<6xui16>
+  // CHECK: ( 0, 1, 32767, 32768, 65534, 65535 )
+
+  %9 = std.constant dense<[-2147483648, -2147483647, -1, 0, 1,
+                            2147483647, 2147483648, 4294967294, 4294967295]> : vector<9xi32>
+  vector.print %9 : vector<9xi32>
+  // CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647, -2147483648, -2, -1 )
+
+  %10 = std.constant dense<[-2147483648, -2147483647, -1, 0, 1, 2147483647]> : vector<6xsi32>
+  vector.print %10 : vector<6xsi32>
+  // CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647 )
+
+  %11 = std.constant dense<[0, 1, 2147483647, 2147483648, 4294967294, 4294967295]> : vector<6xui32>
+  vector.print %11 : vector<6xui32>
+  // CHECK: ( 0, 1, 2147483647, 2147483648, 4294967294, 4294967295 )
+
+  %12 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1,
+                             9223372036854775807, 9223372036854775808,
+                             18446744073709551614, 18446744073709551615]> : vector<9xi64>
+  vector.print %12 : vector<9xi64>
+  // CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807, -9223372036854775808, -2, -1 )
+
+  %13 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1,
+                             9223372036854775807]> : vector<6xsi64>
+  vector.print %13 : vector<6xsi64>
+  // CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807 )
+
+  %14 = std.constant dense<[0, 1, 9223372036854775807, 9223372036854775808,
+                            18446744073709551614, 18446744073709551615]> : vector<6xui64>
+  vector.print %14 : vector<6xui64>
+  // CHECK: ( 0, 1, 9223372036854775807, 9223372036854775808, 18446744073709551614, 18446744073709551615 )
+
+  return
+}

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6ad17d77069c..b48b435d0278 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1319,44 +1319,96 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
     if (typeConverter.convertType(printType) == nullptr)
       return failure();
 
-    // Make sure element type has runtime support (currently just Float/Double).
+    // Make sure element type has runtime support.
+    PrintConversion conversion = PrintConversion::None;
     VectorType vectorType = printType.dyn_cast<VectorType>();
     Type eltType = vectorType ? vectorType.getElementType() : printType;
-    int64_t rank = vectorType ? vectorType.getRank() : 0;
     Operation *printer;
-    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
-      printer = getPrintI32(op);
-    else if (eltType.isSignlessInteger(64))
-      printer = getPrintI64(op);
-    else if (eltType.isF32())
+    if (eltType.isF32()) {
       printer = getPrintFloat(op);
-    else if (eltType.isF64())
+    } else if (eltType.isF64()) {
       printer = getPrintDouble(op);
-    else
+    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
+      // Integers need a zero or sign extension on the operand
+      // (depending on the source type) as well as a signed or
+      // unsigned print method. Up to 64-bit is supported.
+      unsigned width = intTy.getWidth();
+      if (intTy.isUnsigned()) {
+        if (width <= 32) {
+          if (width < 32)
+            conversion = PrintConversion::ZeroExt32;
+          printer = getPrintU32(op);
+        } else if (width <= 64) {
+          if (width < 64)
+            conversion = PrintConversion::ZeroExt64;
+          printer = getPrintU64(op);
+        } else {
+          return failure();
+        }
+      } else {
+        assert(intTy.isSignless() || intTy.isSigned());
+        if (width <= 32) {
+          // Note that we *always* zero extend booleans (1-bit integers),
+          // so that true/false is printed as 1/0 rather than -1/0.
+          if (width == 1)
+            conversion = PrintConversion::ZeroExt32;
+          else if (width < 32)
+            conversion = PrintConversion::SignExt32;
+          printer = getPrintI32(op);
+        } else if (width <= 64) {
+          if (width < 64)
+            conversion = PrintConversion::SignExt64;
+          printer = getPrintI64(op);
+        } else {
+          return failure();
+        }
+      }
+    } else {
       return failure();
+    }
 
     // Unroll vector into elementary print calls.
-    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
+    int64_t rank = vectorType ? vectorType.getRank() : 0;
+    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
+              conversion);
     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
     rewriter.eraseOp(op);
     return success();
   }
 
 private:
+  enum class PrintConversion {
+    None,
+    ZeroExt32,
+    SignExt32,
+    ZeroExt64,
+    SignExt64
+  };
+
   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                  Value value, VectorType vectorType, Operation *printer,
-                 int64_t rank) const {
+                 int64_t rank, PrintConversion conversion) const {
     Location loc = op->getLoc();
     if (rank == 0) {
-      if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
-        // Convert i1 (bool) to i32 so we can use the print_i32 method.
-        // This avoids the need for a print_i1 method with an unclear ABI.
-        auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
-        auto trueVal = rewriter.create<ConstantOp>(
-            loc, i32Type, rewriter.getI32IntegerAttr(1));
-        auto falseVal = rewriter.create<ConstantOp>(
-            loc, i32Type, rewriter.getI32IntegerAttr(0));
-        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
+      switch (conversion) {
+      case PrintConversion::ZeroExt32:
+        value = rewriter.create<ZeroExtendIOp>(
+            loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
+        break;
+      case PrintConversion::SignExt32:
+        value = rewriter.create<SignExtendIOp>(
+            loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
+        break;
+      case PrintConversion::ZeroExt64:
+        value = rewriter.create<ZeroExtendIOp>(
+            loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
+        break;
+      case PrintConversion::SignExt64:
+        value = rewriter.create<SignExtendIOp>(
+            loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
+        break;
+      case PrintConversion::None:
+        break;
       }
       emitCall(rewriter, loc, printer, value);
       return;
@@ -1372,7 +1424,8 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
           rank > 1 ? reducedType : vectorType.getElementType());
       Value nestedVal =
           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
-      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
+      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
+                conversion);
       if (d != dim - 1)
         emitCall(rewriter, loc, printComma);
     }
@@ -1410,6 +1463,14 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
     return getPrint(op, "print_i64",
                     LLVM::LLVMType::getInt64Ty(op->getContext()));
   }
+  Operation *getPrintU32(Operation *op) const {
+    return getPrint(op, "printU32",
+                    LLVM::LLVMType::getInt32Ty(op->getContext()));
+  }
+  Operation *getPrintU64(Operation *op) const {
+    return getPrint(op, "printU64",
+                    LLVM::LLVMType::getInt64Ty(op->getContext()));
+  }
   Operation *getPrintFloat(Operation *op) const {
     return getPrint(op, "print_f32",
                     LLVM::LLVMType::getFloatTy(op->getContext()));

diff  --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
index ad5be24378ce..6efc48768a96 100644
--- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
@@ -25,6 +25,8 @@
 // details of our vectors. Also useful for direct LLVM IR output.
 extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); }
 extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); }
+extern "C" void printU32(uint32_t i) { fprintf(stdout, "%" PRIu32, i); }
+extern "C" void printU64(uint64_t l) { fprintf(stdout, "%" PRIu64, l); }
 extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); }
 extern "C" void print_f64(double d) { fprintf(stdout, "%lg", d); }
 extern "C" void print_open() { fputs("( ", stdout); }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 82db2c55c906..d382c50f9132 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -433,14 +433,45 @@ func @vector_print_scalar_i1(%arg0: i1) {
   vector.print %arg0 : i1
   return
 }
+//
+// Type "boolean" always uses zero extension.
+//
 // CHECK-LABEL: llvm.func @vector_print_scalar_i1(
 // CHECK-SAME: %[[A:.*]]: !llvm.i1)
-//       CHECK: %[[T:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-//       CHECK: %[[F:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-//       CHECK: %[[S:.*]] = llvm.select %[[A]], %[[T]], %[[F]] : !llvm.i1, !llvm.i32
+//       CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i1 to !llvm.i32
+//       CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
+//       CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_i4(%arg0: i4) {
+  vector.print %arg0 : i4
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i4(
+// CHECK-SAME: %[[A:.*]]: !llvm.i4)
+//       CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32
 //       CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
 //       CHECK: llvm.call @print_newline() : () -> ()
 
+func @vector_print_scalar_si4(%arg0: si4) {
+  vector.print %arg0 : si4
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_si4(
+// CHECK-SAME: %[[A:.*]]: !llvm.i4)
+//       CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32
+//       CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
+//       CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_ui4(%arg0: ui4) {
+  vector.print %arg0 : ui4
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui4(
+// CHECK-SAME: %[[A:.*]]: !llvm.i4)
+//       CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i4 to !llvm.i32
+//       CHECK: llvm.call @printU32(%[[S]]) : (!llvm.i32) -> ()
+//       CHECK: llvm.call @print_newline() : () -> ()
+
 func @vector_print_scalar_i32(%arg0: i32) {
   vector.print %arg0 : i32
   return
@@ -450,6 +481,45 @@ func @vector_print_scalar_i32(%arg0: i32) {
 //       CHECK:    llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
 
+func @vector_print_scalar_ui32(%arg0: ui32) {
+  vector.print %arg0 : ui32
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui32(
+// CHECK-SAME: %[[A:.*]]: !llvm.i32)
+//       CHECK:    llvm.call @printU32(%[[A]]) : (!llvm.i32) -> ()
+//       CHECK:    llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_i40(%arg0: i40) {
+  vector.print %arg0 : i40
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i40(
+// CHECK-SAME: %[[A:.*]]: !llvm.i40)
+//       CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64
+//       CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> ()
+//       CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_si40(%arg0: si40) {
+  vector.print %arg0 : si40
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_si40(
+// CHECK-SAME: %[[A:.*]]: !llvm.i40)
+//       CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64
+//       CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> ()
+//       CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_ui40(%arg0: ui40) {
+  vector.print %arg0 : ui40
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui40(
+// CHECK-SAME: %[[A:.*]]: !llvm.i40)
+//       CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i40 to !llvm.i64
+//       CHECK: llvm.call @printU64(%[[S]]) : (!llvm.i64) -> ()
+//       CHECK: llvm.call @print_newline() : () -> ()
+
 func @vector_print_scalar_i64(%arg0: i64) {
   vector.print %arg0 : i64
   return
@@ -459,6 +529,15 @@ func @vector_print_scalar_i64(%arg0: i64) {
 //       CHECK:    llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
 
+func @vector_print_scalar_ui64(%arg0: ui64) {
+  vector.print %arg0 : ui64
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui64(
+// CHECK-SAME: %[[A:.*]]: !llvm.i64)
+//       CHECK:    llvm.call @printU64(%[[A]]) : (!llvm.i64) -> ()
+//       CHECK:    llvm.call @print_newline() : () -> ()
+
 func @vector_print_scalar_f32(%arg0: f32) {
   vector.print %arg0 : f32
   return


        


More information about the Mlir-commits mailing list