[Mlir-commits] [mlir] 069d7d7 - [mlir][vector] Fix crash in extractelement vec distribution
Thomas Raoux
llvmlistbot at llvm.org
Tue Jan 10 18:35:50 PST 2023
Author: Thomas Raoux
Date: 2023-01-11T02:35:12Z
New Revision: 069d7d7e4868dd7817b8b0c6858ac2334c1a4d89
URL: https://github.com/llvm/llvm-project/commit/069d7d7e4868dd7817b8b0c6858ac2334c1a4d89
DIFF: https://github.com/llvm/llvm-project/commit/069d7d7e4868dd7817b8b0c6858ac2334c1a4d89.diff
LOG: [mlir][vector] Fix crash in extractelement vec distribution
Prevent creating a vector of size 0, which would fail the verifier.
A 1-D vector with a single element should be treated like a 0-D vector.
Differential Revision: https://reviews.llvm.org/D141452
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 07e608ab4b10..c8b0fc48da06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -995,19 +995,20 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
VectorType extractSrcType = extractOp.getVectorType();
- bool is0dExtract = extractSrcType.getRank() == 0;
+ bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
Type elType = extractSrcType.getElementType();
VectorType distributedVecType;
- if (!is0dExtract) {
+ if (!is0dOrVec1Extract) {
assert(extractSrcType.getRank() == 1 &&
"expected that extractelement src rank is 0 or 1");
+ if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
+ return failure();
int64_t elementsPerLane =
extractSrcType.getShape()[0] / warpOp.getWarpSize();
distributedVecType = VectorType::get({elementsPerLane}, elType);
} else {
distributedVecType = extractSrcType;
}
-
// Yield source vector from warp op.
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
@@ -1019,9 +1020,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
// 0d extract: The new warp op broadcasts the source vector to all lanes.
// All lanes extract the scalar.
- if (is0dExtract) {
- Value newExtract =
- rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+ if (is0dOrVec1Extract) {
+ Value newExtract;
+ if (extractSrcType.getRank() == 1) {
+ newExtract = rewriter.create<vector::ExtractElementOp>(
+ loc, distributedVec,
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ } else {
+ newExtract =
+ rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+ }
newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 2dd54771d897..b5087feaed02 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -761,6 +761,26 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
// -----
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
+// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32>
+// CHECK-PROP: vector.yield %[[V]] : vector<1xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32>
+// CHECK-PROP: return %[[E]] : f32
+func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+ %0 = "some_def"() : () -> (vector<1xf32>)
+ %c0 = arith.constant 0 : index
+ %1 = vector.extractelement %0[%c0 : index] : vector<1xf32>
+ vector.yield %1 : f32
+ }
+ return %r : f32
+}
+
+// -----
+
// CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
// CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
// CHECK-PROP-LABEL: func.func @vector_extractelement_1d(
More information about the Mlir-commits
mailing list