[Mlir-commits] [mlir] cc311a1 - [mlir][Vector] Support 0-D vectors in `VectorPrintOpConversion`
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Nov 25 12:16:22 PST 2021
Author: Michal Terepeta
Date: 2021-11-25T20:12:18Z
New Revision: cc311a155aa9e3d7ba67ec6d65948952a314c307
URL: https://github.com/llvm/llvm-project/commit/cc311a155aa9e3d7ba67ec6d65948952a314c307
DIFF: https://github.com/llvm/llvm-project/commit/cc311a155aa9e3d7ba67ec6d65948952a314c307.diff
LOG: [mlir][Vector] Support 0-D vectors in `VectorPrintOpConversion`
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D114549
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 108e664f03cad..5ea34d03bec79 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -57,8 +57,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
Value val, Type llvmType, int64_t rank, int64_t pos) {
- assert(rank > 0 && "0-D vector corner case should have been handled already");
- if (rank == 1) {
+ if (rank <= 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(idxType),
@@ -987,7 +986,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
- emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
+ Type type = vectorType ? vectorType : eltType;
+ emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(
@@ -1006,10 +1006,12 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
};
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
- Value value, VectorType vectorType, Operation *printer,
- int64_t rank, PrintConversion conversion) const {
+ Value value, Type type, Operation *printer, int64_t rank,
+ PrintConversion conversion) const {
+ VectorType vectorType = type.dyn_cast<VectorType>();
Location loc = op->getLoc();
- if (rank == 0) {
+ if (!vectorType) {
+ assert(rank == 0 && "The scalar case expects rank == 0");
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
@@ -1030,12 +1032,29 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
Operation *printComma =
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
+
+ if (rank <= 1) {
+ auto reducedType = vectorType.getElementType();
+ auto llvmType = typeConverter->convertType(reducedType);
+ int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
+ for (int64_t d = 0; d < dim; ++d) {
+ Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
+ llvmType, /*rank=*/0, /*pos=*/d);
+ emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
+ conversion);
+ if (d != dim - 1)
+ emitCall(rewriter, loc, printComma);
+ }
+ emitCall(
+ rewriter, loc,
+ LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
+ return;
+ }
+
int64_t dim = vectorType.getDimSize(0);
for (int64_t d = 0; d < dim; ++d) {
- auto reducedType =
- rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
- auto llvmType = typeConverter->convertType(
- rank > 1 ? reducedType : vectorType.getElementType());
+ auto reducedType = reducedVectorTypeFront(vectorType);
+ auto llvmType = typeConverter->convertType(reducedType);
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 033e2d812b3f2..c700b6bcb5d49 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -832,6 +832,23 @@ func @vector_print_scalar_f64(%arg0: f64) {
// -----
+func @vector_print_vector_0d(%arg0: vector<f32>) {
+ vector.print %arg0 : vector<f32>
+ return
+}
+// CHECK-LABEL: @vector_print_vector_0d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: llvm.call @printOpen() : () -> ()
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xf32>
+// CHECK: llvm.call @printF32(%[[T2]]) : (f32) -> ()
+// CHECK: llvm.call @printClose() : () -> ()
+// CHECK: llvm.call @printNewline() : () -> ()
+// CHECK: return
+
+// -----
+
func @vector_print_vector(%arg0: vector<2x2xf32>) {
vector.print %arg0 : vector<2x2xf32>
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index b3052ebd7a600..e7cbece4b1ed3 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -15,10 +15,20 @@ func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
return %1: vector<f32>
}
+func @print_vector_0d(%a: vector<f32>) {
+ // CHECK: ( 42 )
+ vector.print %a: vector<f32>
+ return
+}
+
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
%2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
call @extract_element_0d(%2) : (vector<f32>) -> ()
+
+ %3 = arith.constant dense<42.0> : vector<f32>
+ call @print_vector_0d(%3) : (vector<f32>) -> ()
+
return
}
More information about the Mlir-commits mailing list