[Mlir-commits] [mlir] ad100b3 - [mlir][vector] Fix dominance error in warp vector distribution (#77771)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 12 06:08:17 PST 2024


Author: Matthias Springer
Date: 2024-01-12T15:08:13+01:00
New Revision: ad100b36e728880391df0c3941cdfb1d53517ac7

URL: https://github.com/llvm/llvm-project/commit/ad100b36e728880391df0c3941cdfb1d53517ac7
DIFF: https://github.com/llvm/llvm-project/commit/ad100b36e728880391df0c3941cdfb1d53517ac7.diff

LOG: [mlir][vector] Fix dominance error in warp vector distribution (#77771)

This commit fixes a failure of the `vector-warp-distribute.mlir` test that
occurs when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled.

```
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: error: operand #0 does not dominate this use
    %1 = vector.extract %0[9] : f32 from vector<64xf32>
         ^
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: see current operation: %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: operand defined here (op in a child region)
"func.func"() <{function_type = (index) -> f32, sym_name = "vector_extract_1d"}> ({
^bb0(%arg0: index):
  %0:2 = "vector.warp_execute_on_lane_0"(%arg0) <{warp_size = 32 : i64}> ({
    %7 = "some_def"() : () -> vector<64xf32>
    %8 = "arith.constant"() <{value = 9 : index}> : () -> index
    %9 = "vector.extractelement"(%7, %8) : (vector<64xf32>, index) -> f32
    "vector.yield"(%9, %7) : (f32, vector<64xf32>) -> ()
  }) : (index) -> (f32, vector<2xf32>)
  %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
  %2 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 mod 2)>}> : (index) -> index
  %3 = "vector.extractelement"(%0#1, %2) : (vector<2xf32>, index) -> f32
  %4 = "arith.index_cast"(%1) : (index) -> i32
  %5 = "arith.constant"() <{value = 32 : i32}> : () -> i32
  %6:2 = "gpu.shuffle"(%3, %4, %5) <{mode = #gpu<shuffle_mode idx>}> : (f32, i32, i32) -> (f32, i1)
  "func.return"(%6#0) : (f32) -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```

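The position at which `vector.extractelement` extracts is defined inside
the warp op region, so it must also be distributed, i.e., yielded from the
warp op, before ops created after the warp op can use it. The fix in
`WarpOpExtractElement` is similar to what `WarpOpInsertElement` does.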

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b369964bd01c51..9d5ad20d4715b1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1321,11 +1321,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     } else {
       distributedVecType = extractSrcType;
     }
-    // Yield source vector from warp op.
+    // Yield source vector and position (if present) from warp op.
+    SmallVector<Value> additionalResults{extractOp.getVector()};
+    SmallVector<Type> additionalResultTypes{distributedVecType};
+    if (static_cast<bool>(extractOp.getPosition())) {
+      additionalResults.push_back(extractOp.getPosition());
+      additionalResultTypes.push_back(extractOp.getPosition().getType());
+    }
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
+        rewriter, warpOp, additionalResults, additionalResultTypes,
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
@@ -1354,14 +1360,16 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
     // tid of extracting thread: pos / elementsPerLane
     Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
-        loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
+        loc, sym0.ceilDiv(elementsPerLane),
+        newWarpOp->getResult(newRetIndices[1]));
     // Extract at position: pos % elementsPerLane
     Value pos =
         elementsPerLane == 1
             ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
             : rewriter
-                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
-                                                 extractOp.getPosition())
+                  .create<affine::AffineApplyOp>(
+                      loc, sym0 % elementsPerLane,
+                      newWarpOp->getResult(newRetIndices[1]))
                   .getResult();
     Value extracted =
         rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);


        


More information about the Mlir-commits mailing list