[Mlir-commits] [mlir] b8880f5 - [mlir] [VectorOps] generalize printing support for integers
Aart Bik
llvmlistbot at llvm.org
Fri Sep 25 04:52:42 PDT 2020
Author: Aart Bik
Date: 2020-09-25T04:52:21-07:00
New Revision: b8880f5f97bf1628b2c9606e96abcd612dc7d747
URL: https://github.com/llvm/llvm-project/commit/b8880f5f97bf1628b2c9606e96abcd612dc7d747
DIFF: https://github.com/llvm/llvm-project/commit/b8880f5f97bf1628b2c9606e96abcd612dc7d747.diff
LOG: [mlir] [VectorOps] generalize printing support for integers
This generalizes printing beyond just i1, i32, and i64, and also accounts
for signed and unsigned interpretation in the output.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D88290
Added:
mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/ExecutionEngine/CRunnerUtils.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir
new file mode 100644
index 000000000000..946f02b0e3b9
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-print-int.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+//
+// Test various signless, signed, unsigned integer types.
+//
+func @entry() {
+ %0 = std.constant dense<[true, false, -1, 0, 1]> : vector<5xi1>
+ vector.print %0 : vector<5xi1>
+ // CHECK: ( 1, 0, 1, 0, 1 )
+
+ %1 = std.constant dense<[true, false, -1, 0]> : vector<4xsi1>
+ vector.print %1 : vector<4xsi1>
+ // CHECK: ( 1, 0, 1, 0 )
+
+ %2 = std.constant dense<[true, false, 0, 1]> : vector<4xui1>
+ vector.print %2 : vector<4xui1>
+ // CHECK: ( 1, 0, 0, 1 )
+
+ %3 = std.constant dense<[-128, -127, -1, 0, 1, 127, 128, 254, 255]> : vector<9xi8>
+ vector.print %3 : vector<9xi8>
+ // CHECK: ( -128, -127, -1, 0, 1, 127, -128, -2, -1 )
+
+ %4 = std.constant dense<[-128, -127, -1, 0, 1, 127]> : vector<6xsi8>
+ vector.print %4 : vector<6xsi8>
+ // CHECK: ( -128, -127, -1, 0, 1, 127 )
+
+ %5 = std.constant dense<[0, 1, 127, 128, 254, 255]> : vector<6xui8>
+ vector.print %5 : vector<6xui8>
+ // CHECK: ( 0, 1, 127, 128, 254, 255 )
+
+ %6 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767, 32768, 65534, 65535]> : vector<9xi16>
+ vector.print %6 : vector<9xi16>
+ // CHECK: ( -32768, -32767, -1, 0, 1, 32767, -32768, -2, -1 )
+
+ %7 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767]> : vector<6xsi16>
+ vector.print %7 : vector<6xsi16>
+ // CHECK: ( -32768, -32767, -1, 0, 1, 32767 )
+
+ %8 = std.constant dense<[0, 1, 32767, 32768, 65534, 65535]> : vector<6xui16>
+ vector.print %8 : vector<6xui16>
+ // CHECK: ( 0, 1, 32767, 32768, 65534, 65535 )
+
+ %9 = std.constant dense<[-2147483648, -2147483647, -1, 0, 1,
+ 2147483647, 2147483648, 4294967294, 4294967295]> : vector<9xi32>
+ vector.print %9 : vector<9xi32>
+ // CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647, -2147483648, -2, -1 )
+
+ %10 = std.constant dense<[-2147483648, -2147483647, -1, 0, 1, 2147483647]> : vector<6xsi32>
+ vector.print %10 : vector<6xsi32>
+ // CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647 )
+
+ %11 = std.constant dense<[0, 1, 2147483647, 2147483648, 4294967294, 4294967295]> : vector<6xui32>
+ vector.print %11 : vector<6xui32>
+ // CHECK: ( 0, 1, 2147483647, 2147483648, 4294967294, 4294967295 )
+
+ %12 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1,
+ 9223372036854775807, 9223372036854775808,
+ 18446744073709551614, 18446744073709551615]> : vector<9xi64>
+ vector.print %12 : vector<9xi64>
+ // CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807, -9223372036854775808, -2, -1 )
+
+ %13 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1,
+ 9223372036854775807]> : vector<6xsi64>
+ vector.print %13 : vector<6xsi64>
+ // CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807 )
+
+ %14 = std.constant dense<[0, 1, 9223372036854775807, 9223372036854775808,
+ 18446744073709551614, 18446744073709551615]> : vector<6xui64>
+ vector.print %14 : vector<6xui64>
+ // CHECK: ( 0, 1, 9223372036854775807, 9223372036854775808, 18446744073709551614, 18446744073709551615 )
+
+ return
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6ad17d77069c..b48b435d0278 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1319,44 +1319,96 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
if (typeConverter.convertType(printType) == nullptr)
return failure();
- // Make sure element type has runtime support (currently just Float/Double).
+ // Make sure element type has runtime support.
+ PrintConversion conversion = PrintConversion::None;
VectorType vectorType = printType.dyn_cast<VectorType>();
Type eltType = vectorType ? vectorType.getElementType() : printType;
- int64_t rank = vectorType ? vectorType.getRank() : 0;
Operation *printer;
- if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
- printer = getPrintI32(op);
- else if (eltType.isSignlessInteger(64))
- printer = getPrintI64(op);
- else if (eltType.isF32())
+ if (eltType.isF32()) {
printer = getPrintFloat(op);
- else if (eltType.isF64())
+ } else if (eltType.isF64()) {
printer = getPrintDouble(op);
- else
+ } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
+ // Integers need a zero or sign extension on the operand
+ // (depending on the source type) as well as a signed or
+ // unsigned print method. Up to 64-bit is supported.
+ unsigned width = intTy.getWidth();
+ if (intTy.isUnsigned()) {
+ if (width <= 32) {
+ if (width < 32)
+ conversion = PrintConversion::ZeroExt32;
+ printer = getPrintU32(op);
+ } else if (width <= 64) {
+ if (width < 64)
+ conversion = PrintConversion::ZeroExt64;
+ printer = getPrintU64(op);
+ } else {
+ return failure();
+ }
+ } else {
+ assert(intTy.isSignless() || intTy.isSigned());
+ if (width <= 32) {
+ // Note that we *always* zero extend booleans (1-bit integers),
+ // so that true/false is printed as 1/0 rather than -1/0.
+ if (width == 1)
+ conversion = PrintConversion::ZeroExt32;
+ else if (width < 32)
+ conversion = PrintConversion::SignExt32;
+ printer = getPrintI32(op);
+ } else if (width <= 64) {
+ if (width < 64)
+ conversion = PrintConversion::SignExt64;
+ printer = getPrintI64(op);
+ } else {
+ return failure();
+ }
+ }
+ } else {
return failure();
+ }
// Unroll vector into elementary print calls.
- emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
+ int64_t rank = vectorType ? vectorType.getRank() : 0;
+ emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
+ conversion);
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
rewriter.eraseOp(op);
return success();
}
private:
+ enum class PrintConversion {
+ None,
+ ZeroExt32,
+ SignExt32,
+ ZeroExt64,
+ SignExt64
+ };
+
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
Value value, VectorType vectorType, Operation *printer,
- int64_t rank) const {
+ int64_t rank, PrintConversion conversion) const {
Location loc = op->getLoc();
if (rank == 0) {
- if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
- // Convert i1 (bool) to i32 so we can use the print_i32 method.
- // This avoids the need for a print_i1 method with an unclear ABI.
- auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
- auto trueVal = rewriter.create<ConstantOp>(
- loc, i32Type, rewriter.getI32IntegerAttr(1));
- auto falseVal = rewriter.create<ConstantOp>(
- loc, i32Type, rewriter.getI32IntegerAttr(0));
- value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
+ switch (conversion) {
+ case PrintConversion::ZeroExt32:
+ value = rewriter.create<ZeroExtendIOp>(
+ loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
+ break;
+ case PrintConversion::SignExt32:
+ value = rewriter.create<SignExtendIOp>(
+ loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
+ break;
+ case PrintConversion::ZeroExt64:
+ value = rewriter.create<ZeroExtendIOp>(
+ loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
+ break;
+ case PrintConversion::SignExt64:
+ value = rewriter.create<SignExtendIOp>(
+ loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
+ break;
+ case PrintConversion::None:
+ break;
}
emitCall(rewriter, loc, printer, value);
return;
@@ -1372,7 +1424,8 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
rank > 1 ? reducedType : vectorType.getElementType());
Value nestedVal =
extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
- emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
+ emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
+ conversion);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
@@ -1410,6 +1463,14 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
return getPrint(op, "print_i64",
LLVM::LLVMType::getInt64Ty(op->getContext()));
}
+ Operation *getPrintU32(Operation *op) const {
+ return getPrint(op, "printU32",
+ LLVM::LLVMType::getInt32Ty(op->getContext()));
+ }
+ Operation *getPrintU64(Operation *op) const {
+ return getPrint(op, "printU64",
+ LLVM::LLVMType::getInt64Ty(op->getContext()));
+ }
Operation *getPrintFloat(Operation *op) const {
return getPrint(op, "print_f32",
LLVM::LLVMType::getFloatTy(op->getContext()));
diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
index ad5be24378ce..6efc48768a96 100644
--- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
@@ -25,6 +25,8 @@
// details of our vectors. Also useful for direct LLVM IR output.
extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); }
extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); }
+extern "C" void printU32(uint32_t i) { fprintf(stdout, "%" PRIu32, i); }
+extern "C" void printU64(uint64_t l) { fprintf(stdout, "%" PRIu64, l); }
extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); }
extern "C" void print_f64(double d) { fprintf(stdout, "%lg", d); }
extern "C" void print_open() { fputs("( ", stdout); }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 82db2c55c906..d382c50f9132 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -433,14 +433,45 @@ func @vector_print_scalar_i1(%arg0: i1) {
vector.print %arg0 : i1
return
}
+//
+// Type "boolean" always uses zero extension.
+//
// CHECK-LABEL: llvm.func @vector_print_scalar_i1(
// CHECK-SAME: %[[A:.*]]: !llvm.i1)
-// CHECK: %[[T:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: %[[F:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[S:.*]] = llvm.select %[[A]], %[[T]], %[[F]] : !llvm.i1, !llvm.i32
+// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i1 to !llvm.i32
+// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_i4(%arg0: i4) {
+ vector.print %arg0 : i4
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i4(
+// CHECK-SAME: %[[A:.*]]: !llvm.i4)
+// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32
// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
+func @vector_print_scalar_si4(%arg0: si4) {
+ vector.print %arg0 : si4
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_si4(
+// CHECK-SAME: %[[A:.*]]: !llvm.i4)
+// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32
+// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_ui4(%arg0: ui4) {
+ vector.print %arg0 : ui4
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui4(
+// CHECK-SAME: %[[A:.*]]: !llvm.i4)
+// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i4 to !llvm.i32
+// CHECK: llvm.call @printU32(%[[S]]) : (!llvm.i32) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
func @vector_print_scalar_i32(%arg0: i32) {
vector.print %arg0 : i32
return
@@ -450,6 +481,45 @@ func @vector_print_scalar_i32(%arg0: i32) {
// CHECK: llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
+func @vector_print_scalar_ui32(%arg0: ui32) {
+ vector.print %arg0 : ui32
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui32(
+// CHECK-SAME: %[[A:.*]]: !llvm.i32)
+// CHECK: llvm.call @printU32(%[[A]]) : (!llvm.i32) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_i40(%arg0: i40) {
+ vector.print %arg0 : i40
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i40(
+// CHECK-SAME: %[[A:.*]]: !llvm.i40)
+// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64
+// CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_si40(%arg0: si40) {
+ vector.print %arg0 : si40
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_si40(
+// CHECK-SAME: %[[A:.*]]: !llvm.i40)
+// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64
+// CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_ui40(%arg0: ui40) {
+ vector.print %arg0 : ui40
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui40(
+// CHECK-SAME: %[[A:.*]]: !llvm.i40)
+// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i40 to !llvm.i64
+// CHECK: llvm.call @printU64(%[[S]]) : (!llvm.i64) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
func @vector_print_scalar_i64(%arg0: i64) {
vector.print %arg0 : i64
return
@@ -459,6 +529,15 @@ func @vector_print_scalar_i64(%arg0: i64) {
// CHECK: llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
+func @vector_print_scalar_ui64(%arg0: ui64) {
+ vector.print %arg0 : ui64
+ return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_ui64(
+// CHECK-SAME: %[[A:.*]]: !llvm.i64)
+// CHECK: llvm.call @printU64(%[[A]]) : (!llvm.i64) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
func @vector_print_scalar_f32(%arg0: f32) {
vector.print %arg0 : f32
return
More information about the Mlir-commits mailing list