[Mlir-commits] [mlir] [mlir][vector] Constrain patterns: vector.contract -> vector.outerproduct (PR #68400)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Oct 6 02:56:10 PDT 2023
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/68400
This patch constrains the patterns for converting `vector.contract` to
`vector.outerproduct` so that
* the reduction dimension is _not unrolled_ if the corresponding
dimension is scalable.
This is necessary as the current lowering is incorrect for scalable
dims. Indeed, instead of the following unrolling that is currently being
generated for `vector.contract` (K is the size of the reduction
dimension):
```
// K times
%lhs = vector.extract %LHS[0]
%rhs = vector.extract %RHS[0]
vector.outerproduct %lhs, %rhs
%lhs = vector.extract %LHS[1]
%rhs = vector.extract %RHS[1]
vector.outerproduct %lhs, %rhs
...
```
we should be generating a `for` loop like the following:
```
scf.for %k = 0 to K step 1
  %lhs = vector.extract %LHS[%k]
  %rhs = vector.extract %RHS[%k]
  vector.outerproduct %lhs, %rhs
```
However, the lowering of `vector.extract` of vector slices with dynamic
indices is incomplete and hence the implementation above wouldn't work
just yet. Instead, this patch effectively disables unrolling in cases
where the generated code would be functionally incorrect (i.e. when the
reduction dimension is scalable).
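To make the problem concrete, here is a minimal standalone sketch (plain C++, not MLIR code) of what goes wrong. It assumes an all-ones input and a hypothetical run-time `vscale` value; `Slices`, `contractSteps` and `firstElement` are names invented for this illustration only. Unrolling by the static size visits only the first `baseK` slices of a reduction dimension that actually holds `baseK * vscale` slices.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model, not MLIR code: the LHS is stored as K slices along the
// reduction dimension, each slice holding M elements (the {k, m} layout).
using Slices = std::vector<std::vector<float>>;

// Accumulate `steps` updates: res[m] += lhs[k][m] * rhs[k].
std::vector<float> contractSteps(const Slices &lhs,
                                 const std::vector<float> &rhs,
                                 std::size_t steps) {
  std::vector<float> res(lhs.front().size(), 0.0f);
  for (std::size_t k = 0; k < steps; ++k)
    for (std::size_t m = 0; m < res.size(); ++m)
      res[m] += lhs[k][m] * rhs[k];
  return res;
}

// First result element for an all-ones input whose reduction dimension has
// static size `baseK` and run-time size `baseK * vscale` (vscale is assumed).
float firstElement(std::size_t baseK, std::size_t vscale,
                   bool unrollByStaticSize) {
  const std::size_t runtimeK = baseK * vscale;
  Slices lhs(runtimeK, std::vector<float>(2, 1.0f));
  std::vector<float> rhs(runtimeK, 1.0f);
  // Unrolling uses the static size; a loop would use the run-time size.
  const std::size_t steps = unrollByStaticSize ? baseK : runtimeK;
  return contractSteps(lhs, rhs, steps)[0];
}
```

With `baseK = 3` and `vscale = 2`, the unrolled variant accumulates 3 terms while the loop accumulates all 6. The two agree only when `vscale` is 1.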
In order to document unsupported cases, a dedicated test file is added:
* "vector-contract-to-outerproduct-transforms-unsupported.mlir"
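The unsupported case comes down to a check on the scalable flag of the reduction dimension. A rough standalone model of that check is sketched below. `VectorShape` and `unrollCountOrFail` are hypothetical stand-ins, not MLIR APIs, and the `-1` sentinel merely mimics returning `failure()`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the two pieces of a vector type the check needs:
// the static size of each dimension and whether that dimension is scalable.
struct VectorShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Returns the number of unrolling steps for `reductionDim`, or -1 when the
// dimension is scalable and unrolling has to be refused.
int64_t unrollCountOrFail(const VectorShape &lhs, unsigned reductionDim) {
  if (lhs.scalable[reductionDim])
    return -1; // Unrolling a scalable dimension would be incorrect.
  return lhs.sizes[reductionDim];
}
```

Under this model, `vector<[2]x3xf32>` with reduction dimension 1 still unrolls 3 times (only the parallel dimension is scalable), whereas `vector<[2]x[3]xf32>` is rejected.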
This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.
From 76c021faddbdf06f80873555a9691be384206da9 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 6 Oct 2023 09:31:53 +0000
Subject: [PATCH] [mlir][vector] Constrain patterns: vector.contract ->
vector.outerproduct
This patch constrains the patterns for converting `vector.contract` to
`vector.outerproduct` so that
* the reduction dimension is _not unrolled_ if the corresponding
dimension is scalable.
This is necessary as the current lowering is incorrect for scalable
dims. Indeed, instead of the following unrolling that is currently being
generated for `vector.contract` (K is the size of the reduction
dimension):
```
// K times
%lhs = vector.extract %LHS[0]
%rhs = vector.extract %RHS[0]
vector.outerproduct %lhs, %rhs
%lhs = vector.extract %LHS[1]
%rhs = vector.extract %RHS[1]
vector.outerproduct %lhs, %rhs
...
```
we should be generating a `for` loop like the following:
```
scf.for %k = 0 to K step 1
  %lhs = vector.extract %LHS[%k]
  %rhs = vector.extract %RHS[%k]
  vector.outerproduct %lhs, %rhs
```
However, the lowering of `vector.extract` of vector slices with dynamic
indices is incomplete and hence the implementation above wouldn't work
just yet. Instead, this patch effectively disables unrolling in cases
where the generated code would be functionally incorrect (i.e. when the
reduction dimension is scalable).
In order to document unsupported cases, a dedicated test file is added:
* "vector-contract-to-outerproduct-transforms-unsupported.mlir"
This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.
---
.../Vector/Transforms/LowerVectorContract.cpp | 57 +++++++++++++------
...o-outerproduct-transforms-unsupported.mlir | 33 +++++++++++
...r-contract-to-outerproduct-transforms.mlir | 53 +++++++++--------
3 files changed, 102 insertions(+), 41 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..c1cc0d7c64de264 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -424,9 +424,14 @@ struct UnrolledOuterProductGenerator
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
- FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+ FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, VectorType lhsType,
+ int reductionDim,
std::optional<Value> maybeMask = std::nullopt) {
- assert(reductionSize > 0);
+ // Unrolling a scalable dimension would be incorrect - bail out.
+ if (lhsType.getScalableDims()[reductionDim])
+ return failure();
+
+ int reductionSize = lhsType.getDimSize(reductionDim);
// Incremental support for masking.
if (mask && !maybeMask.has_value())
return failure();
@@ -459,33 +464,39 @@ struct UnrolledOuterProductGenerator
Value transposedMask = t(mask, {2, 0, 1});
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+ return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+ transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+ return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
transposedMask);
}
// No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
+ return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
+ transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+ return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
transposedMask);
}
if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
return failure();
}
@@ -503,16 +514,20 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+ return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+ transposedMask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+ return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+ transposedMask);
return failure();
}
@@ -528,16 +543,16 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
+ return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
+ return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
+ return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
+ return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
return failure();
}
@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
<< " to map to the same dimension";
});
+ // Unrolling a scalable dimension would be incorrect - bail out.
+ if (lhsType.getScalableDims()[lhsIndex])
+ return failure();
dimSize = lhsType.getDimSize(lhsIndex);
} else if (rhsIndex >= 0) {
iterIndex = iMap[1].getDimPosition(rhsIndex);
+ // Unrolling a scalable dimension would be incorrect - bail out.
+ if (rhsType.getScalableDims()[rhsIndex])
+ return failure();
dimSize = rhsType.getDimSize(rhsIndex);
}
if (iterIndex < 0)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
new file mode 100644
index 000000000000000..a955250107d73d7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+#matvec_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j)>,
+ affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+ indexing_maps = #matvec_accesses,
+ iterator_types = ["parallel", "reduction"]
+}
+
+// Unrolling scalable reduction dim is not supported - bail out
+
+// expected-error@below {{greedy pattern application failed}}
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+ %arg1: vector<[3]xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ } : !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..8ee0a35717ce87a 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -31,19 +31,19 @@
}
// CHECK-LABEL: func.func @masked_extract_contract2(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<2xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3xf32>,
@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}
-// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+// CHECK-LABEL: func.func @masked_extract_contract2_scalable_parallel_dim(
+// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
%arg1: vector<5x7xf32>,