[Mlir-commits] [mlir] 73a9d6d - [mlir][linalg] Fix bug in contraction op vectorization with output perm
llvmlistbot at llvm.org
Fri Jul 23 08:40:03 PDT 2021
Author: thomasraoux
Date: 2021-07-23T08:39:43-07:00
New Revision: 73a9d6d0e200d7553f925ca0f4caae86dc2b2f67
URL: https://github.com/llvm/llvm-project/commit/73a9d6d0e200d7553f925ca0f4caae86dc2b2f67
DIFF: https://github.com/llvm/llvm-project/commit/73a9d6d0e200d7553f925ca0f4caae86dc2b2f67.diff
LOG: [mlir][linalg] Fix bug in contraction op vectorization with output perm
When the output indexing map has a permutation, we need to take it into
account when computing the contraction vector type.
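As a rough sketch of the scenario (hypothetical values mirroring the test added
below, not code from the commit): with an output indexing map (m, n) -> (n, m)
and an output buffer of shape 32x8, the contraction result shape must be the
inverse-permuted 8x32 so that its dimensions follow the canonical iteration
order (m, n).

// Sketch only, assuming the applyPermuationMap helper introduced below.
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
using namespace mlir;

static SmallVector<int64_t> exampleContractResultShape() {
  MLIRContext ctx;
  // Reindexed output map of the transposed-output matmul: (m, n) -> (n, m).
  unsigned perm[] = {1, 0};
  AffineMap outMap = AffineMap::getPermutationMap(perm, &ctx);
  // Shape of the linalg output operand memref<32x8xf32>: {n, m} = {32, 8}.
  SmallVector<int64_t> outShape = {32, 8};
  // Apply the inverse permutation so dims follow iteration order (m, n):
  // {32, 8} -> {8, 32}, i.e. the vector.contract result is vector<8x32xf32>.
  return applyPermuationMap(inversePermutation(outMap),
                            ArrayRef<int64_t>(outShape));
}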
Differential Revision: https://reviews.llvm.org/D106469
Added:
Modified:
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index b308bea6ec67b..41cc03735a2ad 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -503,6 +503,20 @@ AffineMap
getProjectedMap(AffineMap map,
const llvm::SmallDenseSet<unsigned> &projectedDimensions);
+/// Apply a permutation from `map` to `source` and return the result.
+template <typename T>
+SmallVector<T> applyPermuationMap(AffineMap map, llvm::ArrayRef<T> source) {
+ assert(map.isProjectedPermutation());
+ assert(map.getNumInputs() == source.size());
+ SmallVector<T> result;
+ result.reserve(map.getNumResults());
+ for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+ unsigned dim = map.getDimPosition(i);
+ result.push_back(source[dim]);
+ }
+ return result;
+}
+
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os);
return os;
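A quick usage note on the new helper (hypothetical shapes, reusing the includes
from the sketch above): because the map may be a projected permutation, it can
also drop dimensions, which is how UnrollContractionPattern below derives
per-operand slice shapes from the target iteration shape.

static SmallVector<int64_t> exampleLhsShape(MLIRContext *ctx) {
  // lhs map of a matmul, (m, n, k) -> (m, k): identity map with n dropped.
  AffineMap lhsMap =
      AffineMap::getMultiDimIdentityMap(3, ctx).getSubMap({0, 2});
  // Target iteration shape {m, n, k} = {8, 32, 16}.
  SmallVector<int64_t> targetShape = {8, 32, 16};
  // Yields the lhs slice shape {m, k} = {8, 16}.
  return applyPermuationMap(lhsMap, ArrayRef<int64_t>(targetShape));
}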
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 28c4ae92f7408..243eb621ca46a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -191,7 +191,6 @@ getKindForOp(Operation *reductionOp) {
static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
Value value, OpOperand *outputOperand) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
- assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
auto vecType = value.getType().dyn_cast<VectorType>();
if (!vecType || vecType.getShape() == targetVectorType.getShape())
return value;
@@ -245,6 +244,9 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap map =
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
+ SmallVector<int64_t> transposeShape =
+ applyPermuationMap(inversePermutation(map), vectorType.getShape());
+ vectorType = VectorType::get(transposeShape, vectorType.getElementType());
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(b, value, vectorType.getShape());
@@ -569,9 +571,16 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
return VectorizationResult{VectorizationStatus::Failure, nullptr};
ArrayRef<int64_t> outShape =
linalgOp.getShape(linalgOp.getOutputOperand(0));
- auto vType = outShape.empty()
- ? op->getResult(0).getType()
- : VectorType::get(outShape, op->getResult(0).getType());
+ Type vType;
+ if (outShape.empty()) {
+ vType = op->getResult(0).getType();
+ } else {
+ SmallVector<int64_t> resultShape = applyPermuationMap(
+ inversePermutation(reindexIndexingMap(
+ linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)))),
+ outShape);
+ vType = VectorType::get(resultShape, op->getResult(0).getType());
+ }
auto zero = b.create<ConstantOp>(loc, vType, b.getZeroAttr(vType));
// Indexing maps at the time of vector.transfer_read are adjusted to order
// vector dimensions in the same order as the canonical linalg op iteration
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2a99eb6e7063b..77edad70a1adb 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -342,17 +342,6 @@ struct UnrollTransferWritePattern
vector::UnrollVectorOptions options;
};
-template <typename T>
-SmallVector<T> permute(AffineMap map, llvm::ArrayRef<T> source) {
- SmallVector<T> result;
- result.reserve(map.getNumResults());
- for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
- unsigned dim = map.getDimPosition(i);
- result.push_back(source[dim]);
- }
- return result;
-}
-
struct UnrollContractionPattern
: public OpRewritePattern<vector::ContractionOp> {
struct OffsetMapInfo {
@@ -403,7 +392,7 @@ struct UnrollContractionPattern
AffineMap permutationMap,
ArrayRef<int64_t> operandOffets) {
SmallVector<int64_t> operandShape =
- permute(permutationMap, ArrayRef<int64_t>(*targetShape));
+ applyPermuationMap(permutationMap, ArrayRef<int64_t>(*targetShape));
SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
loc, operand, operandOffets, operandShape, operandStrides);
@@ -412,7 +401,7 @@ struct UnrollContractionPattern
// Extract the new lhs operand.
AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
SmallVector<int64_t> lhsOffets =
- permute(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+ applyPermuationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets);
// If there is a mask associated to lhs, extract it as well.
if (slicesOperands.size() > 3)
@@ -421,7 +410,7 @@ struct UnrollContractionPattern
// Extract the new rhs operand.
AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
SmallVector<int64_t> rhsOffets =
- permute(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+ applyPermuationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets);
// If there is a mask associated to rhs, extract it as well.
if (slicesOperands.size() > 4)
@@ -429,7 +418,7 @@ struct UnrollContractionPattern
AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
SmallVector<int64_t> accOffets =
- permute(accPermutationMap, ArrayRef<int64_t>(offsets));
+ applyPermuationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
// If a version of the accumulator has already been computed, use it
// otherwise extract the first version from the original operand.
auto accIt = accCache.find(accOffets);
@@ -439,13 +428,13 @@ struct UnrollContractionPattern
extractOperand(2, contractOp.acc(), accPermutationMap, accOffets);
SmallVector<int64_t> dstShape =
- permute(dstAffineMap, ArrayRef<int64_t>(*targetShape));
+ applyPermuationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, contractOp, slicesOperands, targetType);
SmallVector<int64_t> dstOffets =
- permute(dstAffineMap, ArrayRef<int64_t>(offsets));
+ applyPermuationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
// Save the accumulated value until all the loops are unrolled since the
// reduction loop keeps updating the accumulator.
accCache[dstOffets] = newOp->getResult(0);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a286fd630e971..da4f1e99a9391 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -85,6 +85,44 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// -----
+#matmul_transpose_out_trait = {
+ args_in = 2,
+ args_out = 1,
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @generic_output_transpose
+func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+ %C: memref<32x8xf32>) {
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+ // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
+ // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
+ // CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
+ linalg.generic #matmul_transpose_out_trait
+ ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
+ outs(%C : memref<32x8xf32>) {
+ ^bb(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b: f32
+ %e = addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
+}
+
+// -----
+
#matmul_trait = {
args_in = 2,
args_out = 1,