[Mlir-commits] [mlir] 73a9d6d - [mlir][linalg] Fix bug in contraction op vectorization with output perm

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 23 08:40:03 PDT 2021


Author: thomasraoux
Date: 2021-07-23T08:39:43-07:00
New Revision: 73a9d6d0e200d7553f925ca0f4caae86dc2b2f67

URL: https://github.com/llvm/llvm-project/commit/73a9d6d0e200d7553f925ca0f4caae86dc2b2f67
DIFF: https://github.com/llvm/llvm-project/commit/73a9d6d0e200d7553f925ca0f4caae86dc2b2f67.diff

LOG: [mlir][linalg] Fix bug in contraction op vectorization with output perm

When the output indexing map has a permutation, we need to take it into
account when computing the contraction vector type.

Differential Revision: https://reviews.llvm.org/D106469

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index b308bea6ec67b..41cc03735a2ad 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -503,6 +503,20 @@ AffineMap
 getProjectedMap(AffineMap map,
                 const llvm::SmallDenseSet<unsigned> &projectedDimensions);
 
+/// Apply a permutation from `map` to `source` and return the result.
+template <typename T>
+SmallVector<T> applyPermuationMap(AffineMap map, llvm::ArrayRef<T> source) {
+  assert(map.isProjectedPermutation());
+  assert(map.getNumInputs() == source.size());
+  SmallVector<T> result;
+  result.reserve(map.getNumResults());
+  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+    unsigned dim = map.getDimPosition(i);
+    result.push_back(source[dim]);
+  }
+  return result;
+}
+
 inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
   return os;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 28c4ae92f7408..243eb621ca46a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -191,7 +191,6 @@ getKindForOp(Operation *reductionOp) {
 static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
                             Value value, OpOperand *outputOperand) {
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
   auto vecType = value.getType().dyn_cast<VectorType>();
   if (!vecType || vecType.getShape() == targetVectorType.getShape())
     return value;
@@ -245,6 +244,9 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
     auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
     AffineMap map =
         reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
+    SmallVector<int64_t> transposeShape =
+        applyPermuationMap(inversePermutation(map), vectorType.getShape());
+    vectorType = VectorType::get(transposeShape, vectorType.getElementType());
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
@@ -569,9 +571,16 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
       return VectorizationResult{VectorizationStatus::Failure, nullptr};
     ArrayRef<int64_t> outShape =
         linalgOp.getShape(linalgOp.getOutputOperand(0));
-    auto vType = outShape.empty()
-                     ? op->getResult(0).getType()
-                     : VectorType::get(outShape, op->getResult(0).getType());
+    Type vType;
+    if (outShape.empty()) {
+      vType = op->getResult(0).getType();
+    } else {
+      SmallVector<int64_t> resultShape = applyPermuationMap(
+          inversePermutation(reindexIndexingMap(
+              linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)))),
+          outShape);
+      vType = VectorType::get(resultShape, op->getResult(0).getType());
+    }
     auto zero = b.create<ConstantOp>(loc, vType, b.getZeroAttr(vType));
     // Indexing maps at the time of vector.transfer_read are adjusted to order
     // vector dimensions in the same order as the canonical linalg op iteration

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2a99eb6e7063b..77edad70a1adb 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -342,17 +342,6 @@ struct UnrollTransferWritePattern
   vector::UnrollVectorOptions options;
 };
 
-template <typename T>
-SmallVector<T> permute(AffineMap map, llvm::ArrayRef<T> source) {
-  SmallVector<T> result;
-  result.reserve(map.getNumResults());
-  for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
-    unsigned dim = map.getDimPosition(i);
-    result.push_back(source[dim]);
-  }
-  return result;
-}
-
 struct UnrollContractionPattern
     : public OpRewritePattern<vector::ContractionOp> {
   struct OffsetMapInfo {
@@ -403,7 +392,7 @@ struct UnrollContractionPattern
                                 AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffets) {
         SmallVector<int64_t> operandShape =
-            permute(permutationMap, ArrayRef<int64_t>(*targetShape));
+            applyPermuationMap(permutationMap, ArrayRef<int64_t>(*targetShape));
         SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
         slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
             loc, operand, operandOffets, operandShape, operandStrides);
@@ -412,7 +401,7 @@ struct UnrollContractionPattern
       // Extract the new lhs operand.
       AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
       SmallVector<int64_t> lhsOffets =
-          permute(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+          applyPermuationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
       extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets);
       // If there is a mask associated to lhs, extract it as well.
       if (slicesOperands.size() > 3)
@@ -421,7 +410,7 @@ struct UnrollContractionPattern
       // Extract the new rhs operand.
       AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
       SmallVector<int64_t> rhsOffets =
-          permute(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+          applyPermuationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
       extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets);
       // If there is a mask associated to rhs, extract it as well.
       if (slicesOperands.size() > 4)
@@ -429,7 +418,7 @@ struct UnrollContractionPattern
 
       AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
       SmallVector<int64_t> accOffets =
-          permute(accPermutationMap, ArrayRef<int64_t>(offsets));
+          applyPermuationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
       // If a version of the accumulator has already been computed, use it
       // otherwise extract the first version from the original operand.
       auto accIt = accCache.find(accOffets);
@@ -439,13 +428,13 @@ struct UnrollContractionPattern
         extractOperand(2, contractOp.acc(), accPermutationMap, accOffets);
 
       SmallVector<int64_t> dstShape =
-          permute(dstAffineMap, ArrayRef<int64_t>(*targetShape));
+          applyPermuationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
       auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
       Operation *newOp = cloneOpWithOperandsAndTypes(
           rewriter, loc, contractOp, slicesOperands, targetType);
 
       SmallVector<int64_t> dstOffets =
-          permute(dstAffineMap, ArrayRef<int64_t>(offsets));
+          applyPermuationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
       // Save the accumulated value untill all the loops are unrolled since
       // reduction loop keep updating the accumulator.
       accCache[dstOffets] = newOp->getResult(0);

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a286fd630e971..da4f1e99a9391 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -85,6 +85,44 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 
 // -----
 
+#matmul_transpose_out_trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (n, m)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-DAG: #[[$trans_2d:.*]] =  affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @generic_output_transpose
+func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+                         %C: memref<32x8xf32>) {
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
+  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
+  //  CHECK-SAME:   vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+  //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
+  linalg.generic #matmul_transpose_out_trait
+    ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
+   outs(%C : memref<32x8xf32>) {
+    ^bb(%a: f32, %b: f32, %c: f32) :
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e : f32
+  }
+  return
+}
+
+// -----
+
 #matmul_trait = {
   args_in = 2,
   args_out = 1,


        


More information about the Mlir-commits mailing list