[Mlir-commits] [mlir] cc311a1 - [mlir][Vector] Support 0-D vectors in `VectorPrintOpConversion`

Nicolas Vasilache llvmlistbot at llvm.org
Thu Nov 25 12:16:22 PST 2021


Author: Michal Terepeta
Date: 2021-11-25T20:12:18Z
New Revision: cc311a155aa9e3d7ba67ec6d65948952a314c307

URL: https://github.com/llvm/llvm-project/commit/cc311a155aa9e3d7ba67ec6d65948952a314c307
DIFF: https://github.com/llvm/llvm-project/commit/cc311a155aa9e3d7ba67ec6d65948952a314c307.diff

LOG: [mlir][Vector] Support 0-D vectors in `VectorPrintOpConversion`

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114549

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 108e664f03cad..5ea34d03bec79 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -57,8 +57,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
 static Value extractOne(ConversionPatternRewriter &rewriter,
                         LLVMTypeConverter &typeConverter, Location loc,
                         Value val, Type llvmType, int64_t rank, int64_t pos) {
-  assert(rank > 0 && "0-D vector corner case should have been handled already");
-  if (rank == 1) {
+  if (rank <= 1) {
     auto idxType = rewriter.getIndexType();
     auto constant = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter.convertType(idxType),
@@ -987,7 +986,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 
     // Unroll vector into elementary print calls.
     int64_t rank = vectorType ? vectorType.getRank() : 0;
-    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
+    Type type = vectorType ? vectorType : eltType;
+    emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
               conversion);
     emitCall(rewriter, printOp->getLoc(),
              LLVM::lookupOrCreatePrintNewlineFn(
@@ -1006,10 +1006,12 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
   };
 
   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
-                 Value value, VectorType vectorType, Operation *printer,
-                 int64_t rank, PrintConversion conversion) const {
+                 Value value, Type type, Operation *printer, int64_t rank,
+                 PrintConversion conversion) const {
+    VectorType vectorType = type.dyn_cast<VectorType>();
     Location loc = op->getLoc();
-    if (rank == 0) {
+    if (!vectorType) {
+      assert(rank == 0 && "The scalar case expects rank == 0");
       switch (conversion) {
       case PrintConversion::ZeroExt64:
         value = rewriter.create<arith::ExtUIOp>(
@@ -1030,12 +1032,29 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
     Operation *printComma =
         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
+
+    if (rank <= 1) {
+      auto reducedType = vectorType.getElementType();
+      auto llvmType = typeConverter->convertType(reducedType);
+      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
+      for (int64_t d = 0; d < dim; ++d) {
+        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
+                                     llvmType, /*rank=*/0, /*pos=*/d);
+        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
+                  conversion);
+        if (d != dim - 1)
+          emitCall(rewriter, loc, printComma);
+      }
+      emitCall(
+          rewriter, loc,
+          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
+      return;
+    }
+
     int64_t dim = vectorType.getDimSize(0);
     for (int64_t d = 0; d < dim; ++d) {
-      auto reducedType =
-          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
-      auto llvmType = typeConverter->convertType(
-          rank > 1 ? reducedType : vectorType.getElementType());
+      auto reducedType = reducedVectorTypeFront(vectorType);
+      auto llvmType = typeConverter->convertType(reducedType);
       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                    llvmType, rank, d);
       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 033e2d812b3f2..c700b6bcb5d49 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -832,6 +832,23 @@ func @vector_print_scalar_f64(%arg0: f64) {
 
 // -----
 
+func @vector_print_vector_0d(%arg0: vector<f32>) {
+  vector.print %arg0 : vector<f32>
+  return
+}
+// CHECK-LABEL: @vector_print_vector_0d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>)
+//       CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+//       CHECK: llvm.call @printOpen() : () -> ()
+//       CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xf32>
+//       CHECK: llvm.call @printF32(%[[T2]]) : (f32) -> ()
+//       CHECK: llvm.call @printClose() : () -> ()
+//       CHECK: llvm.call @printNewline() : () -> ()
+//       CHECK: return
+
+// -----
+
 func @vector_print_vector(%arg0: vector<2x2xf32>) {
   vector.print %arg0 : vector<2x2xf32>
   return

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index b3052ebd7a600..e7cbece4b1ed3 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -15,10 +15,20 @@ func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
   return %1: vector<f32>
 }
 
+func @print_vector_0d(%a: vector<f32>) {
+  // CHECK: ( 42 )
+  vector.print %a: vector<f32>
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
   %2 = call  @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
   call  @extract_element_0d(%2) : (vector<f32>) -> ()
+
+  %3 = arith.constant dense<42.0> : vector<f32>
+  call  @print_vector_0d(%3) : (vector<f32>) -> ()
+
   return
 }


        


More information about the Mlir-commits mailing list