[Mlir-commits] [mlir] [mlir][vector] Fix invalid IR in `ContractionOpLowering` (PR #78130)
llvmlistbot at llvm.org
Mon Jan 15 01:20:52 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Matthias Springer (matthias-springer)
If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes `Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.
```
* Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
** Insert : 'vector.transpose'(0x5625b3a8cb30)
** Insert : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
} -> failure : pattern failed to match
} -> failure : pattern failed to match
LLVM ERROR: pattern returned failure but IR did change
```
Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` is merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`. The `greedy pattern application failed` error is no longer produced. That error indicates that the greedy pattern rewrite did not converge; it does not mean that a pattern could not be applied.
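For context, the invariant being enforced is the usual rewrite-pattern contract: all match-time checks must run before any op is created or modified, so that returning `failure()` leaves the IR untouched. The sketch below is a minimal, hypothetical pattern (not the code in this patch); the pattern name and the specific bail-out condition are illustrative, and it assumes the standard `OpRewritePattern` / `PatternRewriter` API plus the usual `vector::ContractionOp` accessors.
```cpp
// Minimal sketch of the "check before create" rule for rewrite patterns.
// The bail-out condition is illustrative; the point is that every early
// return happens before the first rewriter.create call.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

namespace {
struct CheckBeforeCreateExample
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType lhsType = op.getLhsType();

    // 1. All preconditions first. Returning failure() here is safe because
    //    nothing has been created yet.
    if (llvm::any_of(lhsType.getScalableDims(), [](bool b) { return b; }))
      return rewriter.notifyMatchFailure(op, "scalable dims not supported");

    // 2. Only now is it safe to create IR. A real pattern would build the
    //    replacement ops here and finish with rewriter.replaceOp(op, ...);
    //    once the first op is created, the pattern may no longer return
    //    failure() without rolling those ops back.
    return rewriter.notifyMatchFailure(op, "lowering elided in this sketch");
  }
};
} // namespace
```
The patch follows the same shape: the scalable-dimension check is hoisted out of `outerProd` into `getReductionSize`, and `outerProd` (which creates `vector.transpose` ops) is only invoked once that check has passed.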
---
Full diff: https://github.com/llvm/llvm-project/pull/78130.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+57-41)
- (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+16-4)
- (removed) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir (-35)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6ff4c26763d247..5310b9689a3505 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
}
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
- VectorType lhsType, int reductionDim,
+ VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
- // Unrolling a scalable dimension would be incorrect - bail out.
- if (lhsType.getScalableDims()[reductionDim])
- return failure();
-
- int reductionSize = lhsType.getDimSize(reductionDim);
- assert(reductionSize > 0 &&
- "Reduction dim must be a known static size to allow unrolling");
-
// Incremental support for masking.
if (mask && !maybeMask.has_value())
return failure();
@@ -458,6 +450,17 @@ struct UnrolledOuterProductGenerator
return res;
}
+ std::optional<int64_t> getReductionSize(VectorType vecType,
+ int64_t reductionDim) {
+ // Cannot unroll scalable dimension.
+ if (vecType.getScalableDims()[reductionDim])
+ return std::nullopt;
+ int64_t reductionSize = vecType.getDimSize(reductionDim);
+ assert(reductionSize > 0 &&
+ "Reduction dim must be a known static size to allow unrolling");
+ return reductionSize;
+ }
+
/// Two outer parallel, one inner reduction (matmat flavor).
FailureOr<Value> matmat() {
if (!iters({Par(), Par(), Red()}))
@@ -465,42 +468,52 @@ struct UnrolledOuterProductGenerator
// Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
- Value transposedMask = t(mask, {2, 0, 1});
+
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
- Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value tlhs = t(lhs);
+ return outerProd(tlhs, t(rhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
+ }
}
// No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, t(rhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(rhs, t(lhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
- Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value trhs = t(rhs);
+ return outerProd(trhs, t(lhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
+ }
}
if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
return failure();
}
@@ -514,24 +527,23 @@ struct UnrolledOuterProductGenerator
return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
- Value transposedMask = t(mask);
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, t(mask));
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, t(mask));
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, t(mask));
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, t(mask));
return failure();
}
@@ -547,16 +559,20 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
return failure();
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index d86c6158bcdf2f..5c8527f77e3df0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
%x: vector<2xf32>,
%b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[A]]
+ // CHECK-DAG: vector.transpose %[[MASK]]
+ // CHECK-DAG: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
%x: vector<2xf32>,
%b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[A]]
+ // CHECK-DAG: vector.transpose %[[MASK]]
+ // CHECK-DAG: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
return %res : vector<[4]xf32>
}
+// Unrolling scalable reduction dim is not supported - bail out
+// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
+// CHECK: vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+ %arg1: vector<[3]xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
// ============================================================================
// TD sequence
// ============================================================================
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
deleted file mode 100644
index 954aa13c3e77b3..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-
-#matvec_accesses = [
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-// Unrolling scalable reduction dim is not supported - bail out
-
-// expected-error at below {{greedy pattern application failed}}
-func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
- %arg1: vector<[3]xf32>,
- %arg2: vector<[2]xf32>,
- %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
- : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
- return %0 : vector<[2]xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %f = transform.structured.match ops{["func.func"]} in %module_op
- : (!transform.any_op) -> !transform.any_op
-
- transform.apply_patterns to %f {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- } : !transform.any_op
- transform.yield
- }
-}
``````````
https://github.com/llvm/llvm-project/pull/78130