[Mlir-commits] [mlir] 5fbfe2e - [mlir][vector] Add vector.bitcast operation

Thomas Raoux llvmlistbot at llvm.org
Wed Aug 26 14:17:35 PDT 2020


Author: Thomas Raoux
Date: 2020-08-26T14:13:52-07:00
New Revision: 5fbfe2ec4f8baf6a4729f9dc2e4fe16f269921eb

URL: https://github.com/llvm/llvm-project/commit/5fbfe2ec4f8baf6a4729f9dc2e4fe16f269921eb
DIFF: https://github.com/llvm/llvm-project/commit/5fbfe2ec4f8baf6a4729f9dc2e4fe16f269921eb.diff

LOG: [mlir][vector] Add vector.bitcast operation

Based on the RFC discussed here:
https://llvm.discourse.group/t/rfc-vector-standard-add-bitcast-operation/1628/

Adding a vector.bitcast operation that allows casting to a vector of a different
element type. The total bitwidth of the most minor dimension must stay unchanged.

Differential Revision: https://reviews.llvm.org/D86580

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 3dc01b3c0914..22fd036df814 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1525,6 +1525,41 @@ def Vector_ShapeCastOp :
   let hasFolder = 1;
 }
 
+def Vector_BitCastOp :
+  Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>,
+    Arguments<(ins AnyVector:$source)>,
+    Results<(outs AnyVector:$result)>{
+  let summary = "bitcast casts between vectors";
+  let description = [{
+    The bitcast operation casts between vectors of the same rank. All
+    dimensions except the minor (innermost) one must match; the minor 1-D
+    vector is reinterpreted as a vector with a different element type but the
+    same total bitwidth.
+
+    Example:
+
+    ```mlir
+    // Example casting to a smaller element type.
+    %1 = vector.bitcast %0 : vector<5x1x4x3xf32> to vector<5x1x4x6xi16>
+
+    // Example casting to a bigger element type.
+    %3 = vector.bitcast %2 : vector<10x12x8xi8> to vector<10x12x2xi32>
+
+    // Example casting to an element type of the same size.
+    %5 = vector.bitcast %4 : vector<5x1x4x3xf32> to vector<5x1x4x3xi32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return source().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let hasFolder = 1;
+}
+
 def Vector_TypeCastOp :
   Vector_Op<"type_cast", [NoSideEffect]>,
     Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index d3eebc38d5de..7fa62ea34de1 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2300,6 +2300,42 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// VectorBitCastOp
+//===----------------------------------------------------------------------===//
+
+// Verifies a vector.bitcast: every dimension except the minor one must match,
+// both element types must be integer or floating-point (so they have a
+// bitwidth), and the minor 1-D vectors of source and result must have the
+// same total bitwidth.
+static LogicalResult verify(BitCastOp op) {
+  auto sourceVectorType = op.getSourceVectorType();
+  auto resultVectorType = op.getResultVectorType();
+
+  // Equal ranks are enforced by the AllRanksMatch trait; compare every
+  // dimension except the minor one, which is allowed to change size.
+  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
+    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
+      return op.emitOpError("dimension size mismatch at: ") << i;
+  }
+
+  // getElementTypeBitWidth() is only defined for integer and float element
+  // types (AnyVector also admits others, e.g. index); reject them explicitly
+  // instead of querying a bitwidth that does not exist.
+  auto sourceElementType = sourceVectorType.getElementType();
+  auto resultElementType = resultVectorType.getElementType();
+  if (!sourceElementType.isIntOrFloat() || !resultElementType.isIntOrFloat())
+    return op.emitOpError(
+        "source/result element types must be integer or floating-point");
+
+  // The minor 1-D vectors are reinterpreted as one another, so their total
+  // bitwidths must agree.
+  if (sourceVectorType.getElementTypeBitWidth() *
+          sourceVectorType.getShape().back() !=
+      resultVectorType.getElementTypeBitWidth() *
+          resultVectorType.getShape().back())
+    return op.emitOpError(
+        "source/result bitwidth of the minor 1-D vectors must be equal");
+
+  return success();
+}
+
+// Folds vector.bitcast:
+//  - a cast to the same type folds to its source;
+//  - bitcast(bitcast(x : A -> B) : B -> A) folds to x;
+//  - bitcast(bitcast(x : A -> B) : B -> C) is folded in place into
+//    bitcast(x : A -> C). This is valid because each step preserves the
+//    leading dimensions and the minor-vector bitwidth.
+OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
+  // Nop cast.
+  if (source().getType() == result().getType())
+    return source();
+
+  if (auto otherOp = source().getDefiningOp<BitCastOp>()) {
+    // Canceling bitcasts.
+    if (result().getType() == otherOp.source().getType())
+      return otherOp.source();
+
+    // Chained bitcasts: bypass the intermediate cast. Returning this op's own
+    // result signals an in-place fold.
+    setOperand(otherOp.source());
+    return result();
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TypeCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bd6ef4150a0d..1b1362f94884 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -372,3 +372,16 @@ func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9
   // CHECK: return
   return %1, %2 : vector<4x8xf32>, vector<4x9xf32>
 }
+
+// -----
+
+// Both bitcasts below are expected to fold away: %0 casts to its own type and
+// %2 undoes the cast performed by %1, so the function returns its arguments.
+// CHECK-LABEL: bitcast_folding
+//  CHECK-SAME:   %[[A:.*]]: vector<4x8xf32>
+//  CHECK-SAME:   %[[B:.*]]: vector<2xi32>
+//  CHECK:        return %[[A]], %[[B]] : vector<4x8xf32>, vector<2xi32>
+func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf32>, vector<2xi32>) {
+  // Same source and result type: folds to %I1.
+  %0 = vector.bitcast %I1 : vector<4x8xf32> to vector<4x8xf32>
+  %1 = vector.bitcast %I2 : vector<2xi32> to vector<4xi16>
+  // Casts %1 back to the type of %I2: folds to %I2.
+  %2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
+  return %0, %2 : vector<4x8xf32>, vector<2xi32>
+}

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 71d2989661ae..3a081231fc7d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1065,6 +1065,34 @@ func @shape_cast_different_tuple_sizes(
 
 // -----
 
+func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
+  // The result type must be a vector.
+  // expected-error@+1 {{must be vector of any type values}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
+}
+
+// -----
+
+func @bitcast_rank_mismatch(%arg0 : vector<5x1x3x2xf32>) {
+  // Source and result must have the same rank (4 vs 3 here).
+  // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x3x2xf32>
+}
+
+// -----
+
+func @bitcast_shape_mismatch(%arg0 : vector<5x1x3x2xf32>) {
+  // All dimensions except the minor one must match (dimension 1 is 1 vs 2).
+  // expected-error@+1 {{op dimension size mismatch}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x2x3x2xf32>
+}
+
+// -----
+
+func @bitcast_sizemismatch(%arg0 : vector<5x1x3x2xf32>) {
+  // The minor 1-D vectors must have equal total bitwidth (2x32 vs 3x16 bits).
+  // expected-error@+1 {{op source/result bitwidth of the minor 1-D vectors must be equal}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x3xf16>
+}
+
+// -----
+
 func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
   // expected-error at +1 {{'vector.reduction' op unknown reduction kind: joho}}
   %0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6f9990d5d97c..2a62be94e01b 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -298,6 +298,33 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
   return %0, %1, %2, %3, %4 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
 }
 
+// Valid vector.bitcast forms: shrinking (f32 -> f16/i8, i32 -> i8), same-width
+// (i32 -> f32) and widening (i8 -> i32/i16) element types. Each op is expected
+// to be printed back in exactly the form written here.
+// CHECK-LABEL: @bitcast
+func @bitcast(%arg0 : vector<5x1x3x2xf32>,
+                 %arg1 : vector<8x1xi32>,
+                 %arg2 : vector<16x1x8xi8>)
+  -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>) {
+
+  // CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x8xi8>
+  %1 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x8xi8>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x1xi32> to vector<8x4xi8>
+  %2 = vector.bitcast %arg1 : vector<8x1xi32> to vector<8x4xi8>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x1xi32> to vector<8x1xf32>
+  %3 = vector.bitcast %arg1 : vector<8x1xi32> to vector<8x1xf32>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x2xi32>
+  %4 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x2xi32>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x4xi16>
+  %5 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x4xi16>
+
+  return %0, %1, %2, %3, %4, %5 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>
+}
+
 // CHECK-LABEL: @vector_fma
 func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
   // CHECK: vector.fma %{{.*}} : vector<8xf32>


        


More information about the Mlir-commits mailing list