[Mlir-commits] [mlir] c0a354d - [mlir][vector] Fix invalid IR in `ContractionOpLowering` (#78130)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 16 00:40:29 PST 2024
Author: Matthias Springer
Date: 2024-01-16T09:40:24+01:00
New Revision: c0a354dfabfd1534bc6f992b242e5d0ea043120d
URL: https://github.com/llvm/llvm-project/commit/c0a354dfabfd1534bc6f992b242e5d0ea043120d
DIFF: https://github.com/llvm/llvm-project/commit/c0a354dfabfd1534bc6f992b242e5d0ea043120d.diff
LOG: [mlir][vector] Fix invalid IR in `ContractionOpLowering` (#78130)
If a rewrite pattern returns "failure", it must not have modified the
IR. This commit fixes
`Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir`
when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.
```
* Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
** Insert : 'vector.transpose'(0x5625b3a8cb30)
** Insert : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
} -> failure : pattern failed to match
} -> failure : pattern failed to match
LLVM ERROR: pattern returned failure but IR did change
```
Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` has been
merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`.
The `greedy pattern application failed` error is no longer produced.
This error indicates that the greedy pattern rewrite did not
converge; it does not mean that a pattern could not be applied.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
Removed:
mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6ff4c26763d2478..446eb853d2e92d2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
}
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
- VectorType lhsType, int reductionDim,
+ VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
- // Unrolling a scalable dimension would be incorrect - bail out.
- if (lhsType.getScalableDims()[reductionDim])
- return failure();
-
- int reductionSize = lhsType.getDimSize(reductionDim);
- assert(reductionSize > 0 &&
- "Reduction dim must be a known static size to allow unrolling");
-
// Incremental support for masking.
if (mask && !maybeMask.has_value())
return failure();
@@ -458,6 +450,20 @@ struct UnrolledOuterProductGenerator
return res;
}
+ /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
+ /// dimension `reductionDim`. If the dimension is a scalable dimension,
+ /// returns "nullopt".
+ std::optional<int64_t> getReductionSize(VectorType vecType,
+ int64_t reductionDim) {
+ // Cannot unroll scalable dimension.
+ if (vecType.getScalableDims()[reductionDim])
+ return std::nullopt;
+ int64_t reductionSize = vecType.getDimSize(reductionDim);
+ assert(reductionSize > 0 &&
+ "Reduction dim must be a known static size to allow unrolling");
+ return reductionSize;
+ }
+
/// Two outer parallel, one inner reduction (matmat flavor).
FailureOr<Value> matmat() {
if (!iters({Par(), Par(), Red()}))
@@ -465,42 +471,72 @@ struct UnrolledOuterProductGenerator
// Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
- Value transposedMask = t(mask, {2, 0, 1});
+
// Classical row-major matmul: Just permute the lhs.
- if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (layout({{m, k}, {k, n}, {m, n}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ // Note: `t` creates new IR. It must be nested within this `if` check
+ // so that no IR is created when then pattern returns "failure".
+ Value tLhs = t(lhs);
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
- Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value tLhs = t(lhs);
+ Value tRhs = t(rhs);
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
+ }
}
// No need to permute anything.
- if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (layout({{k, m}, {k, n}, {m, n}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 0)) {
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
// Just permute the rhs.
- if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (layout({{k, m}, {n, k}, {m, n}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 0)) {
+ Value tRhs = t(rhs);
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
- if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (layout({{m, k}, {k, n}, {n, m}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value tLhs = t(lhs);
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
- Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value tRhs = t(rhs);
+ Value tLhs = t(lhs);
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
+ if (layout({{k, m}, {k, n}, {n, m}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 0)) {
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
+ if (layout({{k, m}, {n, k}, {n, m}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 0)) {
+ Value tRhs = t(rhs);
+ Value tMask = t(mask, {2, 0, 1});
+ return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
+ }
}
- if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
- if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
return failure();
}
@@ -514,24 +550,37 @@ struct UnrolledOuterProductGenerator
return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
- Value transposedMask = t(mask);
// Case mat-vec: transpose.
- if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (layout({{m, k}, {k}, {m}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value tLhs = t(lhs);
+ Value tMask = t(mask);
+ return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
// Case mat-trans-vec: ready to go.
- if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (layout({{k, m}, {k}, {m}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 0)) {
+ Value tMask = t(mask);
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
// Case vec-mat: swap and transpose.
- if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (layout({{k}, {m, k}, {m}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 0)) {
+ Value tRhs = t(rhs);
+ Value tMask = t(mask);
+ return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
// Case vec-mat-trans: swap and ready to go.
- if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (layout({{k}, {k, m}, {m}})) {
+ if (auto reductionSize = getReductionSize(lhsType, 0)) {
+ Value tMask = t(mask);
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
+ }
+ }
return failure();
}
@@ -547,16 +596,20 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
return failure();
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index d86c6158bcdf2fe..412e95bede3a7cf 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
%x: vector<2xf32>,
%b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MASK]]
// CHECK: vector.transpose %[[A]]
+ // CHECK: vector.transpose %[[MASK]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
%x: vector<2xf32>,
%b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MASK]]
// CHECK: vector.transpose %[[A]]
+ // CHECK: vector.transpose %[[MASK]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
return %res : vector<[4]xf32>
}
+// Unrolling scalable reduction dim is not supported - bail out
+// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
+// CHECK: vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+ %arg1: vector<[3]xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
// ============================================================================
// TD sequence
// ============================================================================
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
deleted file mode 100644
index 954aa13c3e77b37..000000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-
-#matvec_accesses = [
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-// Unrolling scalable reduction dim is not supported - bail out
-
-// expected-error at below {{greedy pattern application failed}}
-func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
- %arg1: vector<[3]xf32>,
- %arg2: vector<[2]xf32>,
- %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
- : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
- return %0 : vector<[2]xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %f = transform.structured.match ops{["func.func"]} in %module_op
- : (!transform.any_op) -> !transform.any_op
-
- transform.apply_patterns to %f {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- } : !transform.any_op
- transform.yield
- }
-}
More information about the Mlir-commits
mailing list