[Mlir-commits] [mlir] dc26c03 - [mlir][vector] Add insertOp src shape check for BubbleUpBitCastForStridedSliceInsert
llvmlistbot at llvm.org
Thu Nov 10 16:42:37 PST 2022
Author: stanley-nod
Date: 2022-11-10T16:41:59-08:00
New Revision: dc26c030661a763bdc50c759576fc3c34f3c496a
URL: https://github.com/llvm/llvm-project/commit/dc26c030661a763bdc50c759576fc3c34f3c496a
DIFF: https://github.com/llvm/llvm-project/commit/dc26c030661a763bdc50c759576fc3c34f3c496a.diff
LOG: [mlir][vector] Add insertOp src shape check for BubbleUpBitCastForStridedSliceInsert
Not all vector shapes can be cast to other element types, so we add a check
that skips the rewrite (bubbling the bitcast up through the insertOp) when the
insertOp source shape does not support it. Examples of unsupported casts are
f16 vectors to f32 when the number of source elements is not a multiple of 2,
or int8 to int32 when it is not a multiple of 4.
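The check reduces to a divisibility test: the width ratio between the destination and source element types gives how many source elements fold into one destination element, and the inserted source vector must contain a whole number of such groups. Below is a minimal standalone C++ sketch of that arithmetic. It assumes a hypothetical helper named `isSourceShapeCastable` (not an MLIR API), and it assumes the destination element width is a multiple of the source width, as in the casts-to-fewer-elements this pattern handles.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the check added in this commit.
// srcNumElements: element count of the insert_strided_slice source vector.
// srcBits / dstBits: element bit widths before and after the vector.bitcast.
// Assumes dstBits is a multiple of srcBits (casting to fewer, wider elements).
bool isSourceShapeCastable(int64_t srcNumElements, unsigned srcBits,
                           unsigned dstBits) {
  assert(srcBits != 0 && dstBits % srcBits == 0 && "unsupported width ratio");
  // Number of source elements packed into one destination element,
  // e.g. 2 for f16 -> f32 and 4 for i8 -> i32.
  unsigned ratio = dstBits / srcBits;
  // The source must split into whole destination elements.
  return srcNumElements % ratio == 0;
}
```

For the two new tests, `isSourceShapeCastable(1, 16, 32)` and `isSourceShapeCastable(3, 16, 32)` both return false, so the pattern fails and the `vector.insert_strided_slice` stays ahead of the `vector.bitcast`, which is what the CHECK lines expect. A source of `vector<4xf16>` (4 % 2 == 0) or an `i8` source with 8 elements cast to `i32` (8 % 4 == 0) would pass the check.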
Reviewed By: antiagainst, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D137802
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 0bdaf7b56f829..db804748c0548 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2503,6 +2503,14 @@ struct BubbleUpBitCastForStridedSliceInsert
if (rank != insertOp.getDestVectorType().getRank())
return failure();
+ // Requires that shape of insert op src is castable to dstType.
+ unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+ unsigned destinationWidth =
+ castDstType.getElementType().getIntOrFloatBitWidth();
+ unsigned numElements = destinationWidth / sourceWidth;
+ if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
+ return failure();
+
ArrayAttr newOffsets = insertOp.getOffsets();
assert(newOffsets.size() == rank);
SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 0c86c7c4b6527..e8f5cf722e876 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -507,3 +507,21 @@ func.func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector
%cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
return %cast: vector<16x4x4xf32>
}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> {
+ // CHECK: vector.insert_strided_slice
+ // CHECK-NEXT: vector.bitcast
+ %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
+ %cast = vector.bitcast %0: vector<2xf16> to vector<1xf32>
+ return %cast: vector<1xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> {
+ // CHECK: vector.insert_strided_slice
+ // CHECK-NEXT: vector.bitcast
+ %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16>
+ %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+ return %cast: vector<4xf32>
+}