[Mlir-commits] [mlir] 93d0ade - [mlir][linalg] Remove special case for contraction vectorization

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 21 14:11:39 PDT 2021


Author: thomasraoux
Date: 2021-10-21T14:10:54-07:00
New Revision: 93d0ade17c2da810ee6e0d747c3a594b8bfd0c12

URL: https://github.com/llvm/llvm-project/commit/93d0ade17c2da810ee6e0d747c3a594b8bfd0c12
DIFF: https://github.com/llvm/llvm-project/commit/93d0ade17c2da810ee6e0d747c3a594b8bfd0c12.diff

LOG: [mlir][linalg] Remove special case for contraction vectorization

Handle contraction ops like all the other generic op reductions. This
simplifies the code. We now rely on the reduction-to-contract
canonicalization patterns to keep the same code quality.

Differential Revision: https://reviews.llvm.org/D112171
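
For illustration, a minimal sketch of the new vectorization output,
reconstructed from the updated CHECK lines in vectorization.mlir below
(transfer_read indices, padding values, and permutation maps elided):

    linalg.matmul ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
                 outs(%C : memref<8x32xf32>)

    // ... now vectorizes along the generic path to roughly:
    %lhs = vector.transfer_read %A[...] : memref<8x16xf32>, vector<8x32x16xf32>
    %rhs = vector.transfer_read %B[...] : memref<16x32xf32>, vector<8x32x16xf32>
    %acc = vector.transfer_read %C[...] : memref<8x32xf32>, vector<8x32xf32>
    %mul = arith.mulf %lhs, %rhs : vector<8x32x16xf32>
    %red = vector.multi_reduction #vector.kind<add>, %mul [2]
             : vector<8x32x16xf32> to vector<8x32xf32>
    %res = arith.addf %red, %acc : vector<8x32xf32>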

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 617ea0af3c965..391c21bcd1386 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -189,6 +189,9 @@ struct LinalgStrategyVectorizePass
       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
                                                             filter, options);
     }
+    vector::populateVectorTransferPermutationMapLoweringPatterns(
+        vectorizationPatterns);
+    vector::populateVetorReductionToContractPatterns(vectorizationPatterns);
     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
                               linalg::LinalgCopyVTWForwardingPattern>(
         funcOp.getContext(), /*benefit=*/2);

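The two pattern sets added here replace the cleanup the contraction
special case used to perform up front: the transfer permutation-map
lowering patterns materialize the broadcasting/transposing maps on the
reads, and the reduction-to-contract patterns fold the resulting
mulf + multi_reduction + addf chain back into a vector.contract. A
hedged sketch of the intended round-trip (exact maps and operand
handling may differ):

    %mul = arith.mulf %lhs, %rhs : vector<8x32x16xf32>
    %red = vector.multi_reduction #vector.kind<add>, %mul [2]
             : vector<8x32x16xf32> to vector<8x32xf32>
    %res = arith.addf %red, %acc : vector<8x32xf32>

    // ... is expected to rewrite, once the broadcasts on the reads are
    // folded away, to something like:
    %res = vector.contract {
             indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                              affine_map<(d0, d1, d2) -> (d2, d1)>,
                              affine_map<(d0, d1, d2) -> (d0, d1)>],
             iterator_types = ["parallel", "parallel", "reduction"],
             kind = #vector.kind<add>}
           %lhs2, %rhs2, %acc
         : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
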
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bff12e6f5c868..b7520f1a62fa3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -45,9 +45,6 @@ using llvm::dbgs;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
-// Forward declarations.
-static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
-                                          SmallVectorImpl<Value> &newResults);
 static FailureOr<Operation *>
 vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
 
@@ -495,10 +492,9 @@ static bool isElementwise(Operation *op) {
 /// the absence of good canonicalizations, the amount of work increases.
 /// This is not deemed a problem as we expect canonicalizations and foldings to
 /// aggressively clean up the useless work.
-LogicalResult vectorizeAsLinalgGeneric(
-    OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
-    bool broadcastToMaximalCommonShape = false,
-    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
+static LogicalResult
+vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
+                         SmallVectorImpl<Value> &newResults) {
   Block *block = linalgOp.getBlock();
 
   // 2. Values defined above the region can only be broadcast for now. Make them
@@ -530,8 +526,7 @@ LogicalResult vectorizeAsLinalgGeneric(
     if (linalgOp.getShape(opOperand).empty()) {
       readType = bbarg.getType();
     } else {
-      if (broadcastToMaximalCommonShape &&
-          opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+      if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
         map = inverseAndBroadcastProjectedPermuation(
             linalgOp.getTiedIndexingMap(opOperand));
         readType = VectorType::get(commonVectorShape,
@@ -549,7 +544,7 @@ LogicalResult vectorizeAsLinalgGeneric(
     bvm.map(opOperand->get(), readValue);
   }
 
-  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+  SmallVector<CustomVectorizationHook> hooks;
   // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
@@ -587,61 +582,6 @@ LogicalResult vectorizeAsLinalgGeneric(
 /// This helper is needed atm because the truly generic implementation requires
 /// good vector.multi_reduce folding patterns that are currently NYI.
 // TODO: drop reliance on a specific pattern.
-static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
-                                          SmallVectorImpl<Value> &newResults) {
-  assert(isaContractionOpInterface(linalgOp) &&
-         "expected vectorizeContraction preconditions to be met");
-  Location loc = linalgOp.getLoc();
-  // Vectorize other ops as vector contraction.
-  // TODO: interface.
-  LDBG(""
-           << "Rewrite linalg op as vector.contract: ";
-       linalgOp.dump());
-  // Special function that describes how to vectorize the multiplication op in a
-  // linalg contraction.
-  CustomVectorizationHook vectorizeContraction =
-      [&](Operation *op,
-          const BlockAndValueMapping &bvm) -> VectorizationResult {
-    if (!isa<arith::MulIOp, arith::MulFOp>(op))
-      return VectorizationResult{VectorizationStatus::Failure, nullptr};
-    ArrayRef<int64_t> outShape =
-        linalgOp.getShape(linalgOp.getOutputOperand(0));
-    Type vType;
-    if (outShape.empty()) {
-      vType = op->getResult(0).getType();
-    } else {
-      SmallVector<int64_t> resultShape = applyPermutationMap(
-          inversePermutation(reindexIndexingMap(
-              linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)))),
-          outShape);
-      vType = VectorType::get(resultShape, op->getResult(0).getType());
-    }
-    auto zero = b.create<arith::ConstantOp>(loc, vType, b.getZeroAttr(vType));
-    // Indexing maps at the time of vector.transfer_read are adjusted to order
-    // vector dimensions in the same order as the canonical linalg op iteration
-    // space order.
-    // The indexings for the contraction therefore need to be adjusted.
-    // TODO: consider dropping contraction special casing altogether, this will
-    // require more advanced canonicalizations involving vector.multi_reduction
-    // that are not yet available.
-    SmallVector<AffineMap> indexingMaps;
-    indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
-    llvm::transform(linalgOp.getIndexingMaps(),
-                    std::back_inserter(indexingMaps),
-                    [](AffineMap indexingMap) {
-                      return inversePermutation(reindexIndexingMap(indexingMap))
-                          .compose(indexingMap);
-                    });
-    Operation *contract = b.create<vector::ContractionOp>(
-        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
-        b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
-    return VectorizationResult{VectorizationStatus::NewOp, contract};
-  };
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
-                                  /*broadcastToMaximalCommonShape=*/false,
-                                  {vectorizeContraction});
-}
-
 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
   return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
     return m.isProjectedPermutation(/*allowZerosInResults=*/true);
@@ -674,8 +614,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   }
   if (isElementwise(op))
     return success();
-  if (isaContractionOpInterface(linalgOp))
-    return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
@@ -702,8 +640,6 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
     return failure();
 
   auto linalgOp = cast<LinalgOp>(op);
-  if (isaContractionOpInterface(linalgOp))
-    return vectorizeContraction(b, linalgOp, newResults);
 
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
@@ -721,8 +657,7 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
        << "Vectorize linalg op as a generic by broadcasting to "
           "maximal common shape: "
        << *op);
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
-                                  /*broadcastToMaximalCommonShape=*/true);
+  return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
 }
 
 //----------------------------------------------------------------------------//

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index feb6309ba8102..ab0be6bbeebf4 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -25,7 +25,7 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 //
 //      CHECK-1D: vector.contract
 // CHECK-1D-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-1D-SAME:   : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
+// CHECK-1D-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 //      CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32, #{{.*}}>, vector<8x12xf32>
 //      CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32, #{{.*}}>
@@ -41,6 +41,6 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 //
 //      CHECK-2D: vector.contract
 // CHECK-2D-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-2D-SAME:   : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
+// CHECK-2D-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 //      CHECK-2D: linalg.copy

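The operand type change above (vector<12x16xf32> becomes
vector<16x12xf32>) falls out of dropping the special case: the old path
composed the indexing maps with the inverse iteration-space permutation,
so B was read transposed as n x k, while the contract reconstructed by
the new canonicalization keeps B in its stored k x n layout.
Schematically:

    // Old special case, B read transposed, RHS map (d0, d1, d2) -> (d1, d2):
    vector.contract ... : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
    // New path, B as stored, RHS map (d0, d1, d2) -> (d2, d1):
    vector.contract ... : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
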
diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 206bd7f94a5ce..7e84ee92ff8b3 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -4,8 +4,10 @@
 
 // CHECK-LABEL: contraction_dot
 func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584xf32>, vector<1584xf32> into f32
+
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
+// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
   linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
             outs(%C: memref<f32>)
   return
@@ -15,8 +17,10 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
 
 // CHECK-LABEL: contraction_matvec
 func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
   linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
             outs(%C: memref<1584xf32>)
   return
@@ -26,8 +30,9 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
 
 // CHECK-LABEL: contraction_matmul
 func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
   linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
             outs(%C: memref<1584x1584xf32>)
   return
@@ -37,8 +42,9 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
 
 // CHECK-LABEL: contraction_batch_matmul
 func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
-  // CHECK: vector.contract
-  // CHECK-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
   linalg.batch_matmul
     ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
    outs(%C: memref<1584x1584x1584xf32>)
@@ -58,19 +64,15 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] =  affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @vectorization_test
 func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
-  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
-  //  CHECK-SAME:   vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+  //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  //       CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -96,19 +98,15 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] =  affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @generic_output_transpose
 func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<32x8xf32>) {
-  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
-  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
-  //  CHECK-SAME:   vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+  //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  //       CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
   linalg.generic #matmul_transpose_out_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -134,19 +132,16 @@ func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] =  affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @vectorization_test_integer
 func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
                                  %C: memref<8x32xi32>) {
-  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<32x16xi32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]],
-  //  CHECK-SAME:   vector<8x16xi32>, vector<32x16xi32> into vector<8x32xi32>
+  //       CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
+  //       CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
+
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
@@ -164,8 +159,9 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
 // CHECK-LABEL: func @vectorization_test_2
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
-  //       CHECK: vector.contract {{.*}} :
-  //                vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+  //       CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+  //       CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
+  //       CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
   linalg.matmul
     ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
    outs(%C: memref<8x32xf32>)
@@ -520,19 +516,16 @@ func @matmul_tensors(
   %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32> {
   //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-  //   CHECK-DAG:   %[[VEC_C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x12xf32>
-  //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<12x4xf32>
+  //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32>
+  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32>
   //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
   //
-  // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
-  // a later canonicalization fuses the add into vector.contract.
-  //       CHECK:   %[[C:.*]] = vector.contract
-  //  CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
-  //  CHECK-SAME:     %[[V0]], %[[V1]], %[[VEC_C0]] :
-  //  CHECK-SAME:     vector<8x4xf32>, vector<12x4xf32> into vector<8x12xf32>
-  //       CHECK:   %[[C2:.*]] = arith.addf %[[V2]], %[[C]] : vector<8x12xf32>
-  //       CHECK:   %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
+  // linalg matmul gets expanded to a 3D reduction; a later canonicalization
+  // converts it back to a 2D contract.
+  //       CHECK:   %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
+  //       CHECK:   %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+  //       CHECK:   %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
+  //       CHECK:   %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
                      outs(%arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32>

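In the multi_reduction CHECK lines above, the bracketed list names the
reduced dimensions. A minimal illustrative example of the semantics,
with hypothetical values:

    // Sums along dim 1: %r[i] = %v[i][0] + %v[i][1] + %v[i][2].
    %r = vector.multi_reduction #vector.kind<add>, %v [1]
         : vector<2x3xf32> to vector<2xf32>
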
diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 96b761c2cfc47..b1468083f52df 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -531,6 +531,14 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                           stage1Patterns);
   }
+  {
+    // Canonicalization patterns
+    RewritePatternSet canonicalizationPatterns(funcOp.getContext());
+    vector::populateVectorTransferPermutationMapLoweringPatterns(
+        canonicalizationPatterns);
+    vector::populateVetorReductionToContractPatterns(canonicalizationPatterns);
+    stage1Patterns.push_back(std::move(canonicalizationPatterns));
+  }
   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
   FrozenRewritePatternSet stage2Patterns =

More information about the Mlir-commits mailing list