[Mlir-commits] [mlir] ad100b3 - [mlir][vector] Fix dominance error in warp vector distribution (#77771)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 12 06:08:17 PST 2024
Author: Matthias Springer
Date: 2024-01-12T15:08:13+01:00
New Revision: ad100b36e728880391df0c3941cdfb1d53517ac7
URL: https://github.com/llvm/llvm-project/commit/ad100b36e728880391df0c3941cdfb1d53517ac7
DIFF: https://github.com/llvm/llvm-project/commit/ad100b36e728880391df0c3941cdfb1d53517ac7.diff
LOG: [mlir][vector] Fix dominance error in warp vector distribution (#77771)
This commit fixes a test failure in `vector-warp-distribute.mlir` that occurs when
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled.
```
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: error: operand #0 does not dominate this use
%1 = vector.extract %0[9] : f32 from vector<64xf32>
^
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: see current operation: %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: operand defined here (op in a child region)
"func.func"() <{function_type = (index) -> f32, sym_name = "vector_extract_1d"}> ({
^bb0(%arg0: index):
%0:2 = "vector.warp_execute_on_lane_0"(%arg0) <{warp_size = 32 : i64}> ({
%7 = "some_def"() : () -> vector<64xf32>
%8 = "arith.constant"() <{value = 9 : index}> : () -> index
%9 = "vector.extractelement"(%7, %8) : (vector<64xf32>, index) -> f32
"vector.yield"(%9, %7) : (f32, vector<64xf32>) -> ()
}) : (index) -> (f32, vector<2xf32>)
%1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
%2 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 mod 2)>}> : (index) -> index
%3 = "vector.extractelement"(%0#1, %2) : (vector<2xf32>, index) -> f32
%4 = "arith.index_cast"(%1) : (index) -> i32
%5 = "arith.constant"() <{value = 32 : i32}> : () -> i32
%6:2 = "gpu.shuffle"(%3, %4, %5) <{mode = #gpu<shuffle_mode idx>}> : (f32, i32, i32) -> (f32, i1)
"func.return"(%6#0) : (f32) -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```
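The position at which `vector.extractelement` extracts must also be
distributed (i.e., yielded from the warp op): it may be defined inside the
warp region, in which case it does not dominate uses created after the warp
op. The fix in `WarpOpExtractElement` is similar to the one in
`WarpOpInsertElement`.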
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b369964bd01c51..9d5ad20d4715b1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1321,11 +1321,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
} else {
distributedVecType = extractSrcType;
}
- // Yield source vector from warp op.
+ // Yield source vector and position (if present) from warp op.
+ SmallVector<Value> additionalResults{extractOp.getVector()};
+ SmallVector<Type> additionalResultTypes{distributedVecType};
+ if (static_cast<bool>(extractOp.getPosition())) {
+ additionalResults.push_back(extractOp.getPosition());
+ additionalResultTypes.push_back(extractOp.getPosition().getType());
+ }
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
+ rewriter, warpOp, additionalResults, additionalResultTypes,
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
@@ -1354,14 +1360,16 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
- loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
+ loc, sym0.ceilDiv(elementsPerLane),
+ newWarpOp->getResult(newRetIndices[1]));
// Extract at position: pos % elementsPerLane
Value pos =
elementsPerLane == 1
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
: rewriter
- .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
- extractOp.getPosition())
+ .create<affine::AffineApplyOp>(
+ loc, sym0 % elementsPerLane,
+ newWarpOp->getResult(newRetIndices[1]))
.getResult();
Value extracted =
rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
More information about the Mlir-commits
mailing list