[Mlir-commits] [mlir] 069d7d7 - [mlir][vector] Fix crash in extractelement vec distribution

Thomas Raoux llvmlistbot at llvm.org
Tue Jan 10 18:35:50 PST 2023


Author: Thomas Raoux
Date: 2023-01-11T02:35:12Z
New Revision: 069d7d7e4868dd7817b8b0c6858ac2334c1a4d89

URL: https://github.com/llvm/llvm-project/commit/069d7d7e4868dd7817b8b0c6858ac2334c1a4d89
DIFF: https://github.com/llvm/llvm-project/commit/069d7d7e4868dd7817b8b0c6858ac2334c1a4d89.diff

LOG: [mlir][vector] Fix crash in extractelement vec distribution

Prevent creating a vector of size 0, which would fail the verifier.
A 1-D vector with a single element should be treated like a 0-D vector.

Differential Revision: https://reviews.llvm.org/D141452

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 07e608ab4b10..c8b0fc48da06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -995,19 +995,20 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
     VectorType extractSrcType = extractOp.getVectorType();
-    bool is0dExtract = extractSrcType.getRank() == 0;
+    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
     Type elType = extractSrcType.getElementType();
     VectorType distributedVecType;
-    if (!is0dExtract) {
+    if (!is0dOrVec1Extract) {
       assert(extractSrcType.getRank() == 1 &&
              "expected that extractelement src rank is 0 or 1");
+      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
+        return failure();
       int64_t elementsPerLane =
           extractSrcType.getShape()[0] / warpOp.getWarpSize();
       distributedVecType = VectorType::get({elementsPerLane}, elType);
     } else {
       distributedVecType = extractSrcType;
     }
-
     // Yield source vector from warp op.
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
@@ -1019,9 +1020,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
     // 0d extract: The new warp op broadcasts the source vector to all lanes.
     // All lanes extract the scalar.
-    if (is0dExtract) {
-      Value newExtract =
-          rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+    if (is0dOrVec1Extract) {
+      Value newExtract;
+      if (extractSrcType.getRank() == 1) {
+        newExtract = rewriter.create<vector::ExtractElementOp>(
+            loc, distributedVec,
+            rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+      } else {
+        newExtract =
+            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+      }
       newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
       return success();
     }

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 2dd54771d897..b5087feaed02 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -761,6 +761,26 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
 
 // -----
 
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<1xf32>
+//       CHECK-PROP:     vector.yield %[[V]] : vector<1xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32>
+//       CHECK-PROP:   return %[[E]] : f32
+func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<1xf32>)
+    %c0 = arith.constant 0 : index
+    %1 = vector.extractelement %0[%c0 : index] : vector<1xf32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
 //       CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
 //       CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
 // CHECK-PROP-LABEL: func.func @vector_extractelement_1d(


        


More information about the Mlir-commits mailing list