[Mlir-commits] [mlir] 1423e8b - [mlir][Vector] Support 0-D vectors in `BitCastOp`

Nicolas Vasilache llvmlistbot at llvm.org
Fri Dec 3 00:57:44 PST 2021


Author: Michal Terepeta
Date: 2021-12-03T08:55:59Z
New Revision: 1423e8bf5dda75877c0414dd26d024fd770d71fb

URL: https://github.com/llvm/llvm-project/commit/1423e8bf5dda75877c0414dd26d024fd770d71fb
DIFF: https://github.com/llvm/llvm-project/commit/1423e8bf5dda75877c0414dd26d024fd770d71fb.diff

LOG: [mlir][Vector] Support 0-D vectors in `BitCastOp`

The implementation only allows bit-casting between two 0-D vectors. We could
probably support casting from/to vectors like `vector<1xf32>`, but I wasn't
convinced that this would be important and it would require breaking the
invariant that `BitCastOp` works only on vectors with equal rank.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114854

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 8eaf785319578..74edc5fe5f9b9 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -675,7 +675,7 @@ def Vector_InsertElementOp :
     position and inserts the source into the destination at the proper position.
 
     Note that this instruction resembles vector.insert, but is restricted to 0-D
-    and 1-D vectors and relaxed to dynamic indices. 
+    and 1-D vectors and relaxed to dynamic indices.
 
     It is meant to be closer to LLVM's version:
     https://llvm.org/docs/LangRef.html#insertelement-instruction
@@ -2025,13 +2025,14 @@ def Vector_ShapeCastOp :
 
 def Vector_BitCastOp :
   Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>,
-    Arguments<(ins AnyVector:$source)>,
-    Results<(outs AnyVector:$result)>{
+    Arguments<(ins AnyVectorOfAnyRank:$source)>,
+    Results<(outs AnyVectorOfAnyRank:$result)>{
   let summary = "bitcast casts between vectors";
   let description = [{
     The bitcast operation casts between vectors of the same rank, the minor 1-D
     vector size is casted to a vector with a different element type but same
-    bitwidth.
+    bitwidth. In case of 0-D vectors, the bitwidth of element types must be
+    equal.
 
     Example:
 
@@ -2044,6 +2045,9 @@ def Vector_BitCastOp :
 
     // Example casting to an element type of the same size.
     %5 = vector.bitcast %4 : vector<5x1x4x3xf32> to vector<5x1x4x3xi32>
+
+    // Example casting of 0-D vectors.
+    %7 = vector.bitcast %6 : vector<f32> to vector<i32>
     ```
   }];
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index bc42922a44858..9b4dce458b7a3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -121,9 +121,9 @@ class VectorBitCastOpConversion
   LogicalResult
   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Only 1-D vectors can be lowered to LLVM.
-    VectorType resultTy = bitCastOp.getType();
-    if (resultTy.getRank() != 1)
+    // Only 0-D and 1-D vectors can be lowered to LLVM.
+    VectorType resultTy = bitCastOp.getResultVectorType();
+    if (resultTy.getRank() > 1)
       return failure();
     Type newResultTy = typeConverter->convertType(resultTy);
     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 859067b2bffe8..76fcb97898e54 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3702,12 +3702,20 @@ static LogicalResult verify(BitCastOp op) {
   }
 
   DataLayout dataLayout = DataLayout::closest(op);
-  if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
-          sourceVectorType.getShape().back() !=
-      dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
-          resultVectorType.getShape().back())
+  auto sourceElementBits =
+      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
+  auto resultElementBits =
+      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
+
+  if (sourceVectorType.getRank() == 0) {
+    if (sourceElementBits != resultElementBits)
+      return op.emitOpError("source/result bitwidth of the 0-D vector element "
+                            "types must be equal");
+  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
+             resultElementBits * resultVectorType.getShape().back()) {
     return op.emitOpError(
         "source/result bitwidth of the minor 1-D vectors must be equal");
+  }
 
   return success();
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ce81c4e36bb63..0c21cf9eecaa1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1,6 +1,20 @@
 // RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s
 
 
+func @bitcast_f32_to_i32_vector_0d(%input: vector<f32>) -> vector<i32> {
+  %0 = vector.bitcast %input : vector<f32> to vector<i32>
+  return %0 : vector<i32>
+}
+
+// CHECK-LABEL: @bitcast_f32_to_i32_vector_0d
+// CHECK-SAME:  %[[input:.*]]: vector<f32>
+// CHECK:       %[[vec_f32_1d:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<f32> to vector<1xf32>
+// CHECK:       %[[vec_i32_1d:.*]] = llvm.bitcast %[[vec_f32_1d]] : vector<1xf32> to vector<1xi32>
+// CHECK:       %[[vec_i32_0d:.*]] = builtin.unrealized_conversion_cast %[[vec_i32_1d]] : vector<1xi32> to vector<i32>
+// CHECK:       return %[[vec_i32_0d]] : vector<i32>
+
+// -----
+
 func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> {
   %0 = vector.bitcast %input : vector<16xf32> to vector<16xi32>
   return %0 : vector<16xi32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 7902976cce59c..fb69798ee1054 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1014,6 +1014,20 @@ func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
 
 // -----
 
+func @bitcast_rank_mismatch_to_0d(%arg0 : vector<1xf32>) {
+  // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
+  %0 = vector.bitcast %arg0 : vector<1xf32> to vector<f32>
+}
+
+// -----
+
+func @bitcast_rank_mismatch_from_0d(%arg0 : vector<f32>) {
+  // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
+  %0 = vector.bitcast %arg0 : vector<f32> to vector<1xf32>
+}
+
+// -----
+
 func @bitcast_rank_mismatch(%arg0 : vector<5x1x3x2xf32>) {
   // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
   %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x3x2xf32>

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 11bc141556e32..2bd0e13f05e4e 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -432,8 +432,9 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
 func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xi32>,
                  %arg2 : vector<16x1x8xi8>,
-                 %arg3 : vector<8x2x1xindex>)
-  -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>) {
+                 %arg3 : vector<8x2x1xindex>,
+                 %arg4 : vector<f32>)
+  -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector<i32>) {
 
   // CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
   %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
@@ -459,7 +460,10 @@ func @bitcast(%arg0 : vector<5x1x3x2xf32>,
   // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x2x1xindex> to vector<8x2x2xf32>
   %7 = vector.bitcast %arg3 : vector<8x2x1xindex> to vector<8x2x2xf32>
 
-  return %0, %1, %2, %3, %4, %5, %6, %7 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>
+  // CHECK: vector.bitcast %{{.*}} : vector<f32> to vector<i32>
+  %8 = vector.bitcast %arg4 : vector<f32> to vector<i32>
+
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector<i32>
 }
 
 // CHECK-LABEL: @vector_fma

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 8e69d658612b8..74bbadaac520b 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -55,6 +55,19 @@ func @broadcast_0d(%a: f32) {
   return
 }
 
+func @bitcast_0d() {
+  %0 = arith.constant 42 : i32
+  %1 = arith.constant dense<0> : vector<i32>
+  %2 = vector.insertelement %0, %1[] : vector<i32>
+  %3 = vector.bitcast %2 : vector<i32> to vector<f32>
+  %4 = vector.extractelement %3[] : vector<f32>
+  %5 = arith.bitcast %4 : f32 to i32
+  // CHECK: 42
+  vector.print %5: i32
+  return
+}
+
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -68,5 +81,7 @@ func @entry() {
   call  @splat_0d(%4) : (f32) -> ()
   call  @broadcast_0d(%4) : (f32) -> ()
 
+  call  @bitcast_0d() : () -> ()
+
   return
 }


        


More information about the Mlir-commits mailing list