[Mlir-commits] [mlir] 02d053e - [mlir][Vector] Add a canonicalization pattern for vector.contract + add

Nicolas Vasilache llvmlistbot at llvm.org
Mon Feb 15 13:25:50 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-15T21:22:36Z
New Revision: 02d053ed2d2ef626c3fc747f5224fad605b46060

URL: https://github.com/llvm/llvm-project/commit/02d053ed2d2ef626c3fc747f5224fad605b46060
DIFF: https://github.com/llvm/llvm-project/commit/02d053ed2d2ef626c3fc747f5224fad605b46060.diff

LOG: [mlir][Vector] Add a canonicalization pattern for vector.contract + add

Differential Revision: https://reviews.llvm.org/D96701

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 0f80b753b2c2..a7c12231a91f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -246,6 +246,8 @@ def Vector_ContractionOp :
       return CombiningKind::ADD;
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ReductionOp :

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 678205d0b5d2..671cd865b1c3 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -658,6 +659,66 @@ Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
   return shape;
 }
 
+/// Canonicalization pattern that folds an addition into the accumulator of a
+/// vector::ContractionOp whose accumulator is a constant zero, i.e. rewrites:
+///
+/// ```mlir
+///    %c0 = constant 0: ...
+///    %c = vector.contract %a, %b, %c0: ...
+///    %e = add %c, %d: ...
+/// ```
+///
+/// into:
+///
+/// ```mlir
+///    %e = vector.contract %a, %b, %d: ...
+/// ```
+///
+/// The contraction may be either operand of the add. When neither operand is
+/// a contraction with a zero-constant accumulator, the pattern reports
+/// failure and leaves the IR untouched.
+// TODO: This should be a folding of Add into Contract in core but while they
+// live in different dialects, it is not possible without unnatural
+// dependencies.
+template <typename AddOpType>
+struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
+  using OpRewritePattern<AddOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AddOpType addOp,
+                                PatternRewriter &rewriter) const override {
+    // Try to fold `otherOperand` into the accumulator of the contraction that
+    // defines `maybeContraction`. On a match, `addOp` is replaced and the new
+    // contraction is returned; otherwise a null op is returned and the IR is
+    // left unchanged.
+    auto canonicalize = [&](Value maybeContraction,
+                            Value otherOperand) -> vector::ContractionOp {
+      vector::ContractionOp contractionOp =
+          dyn_cast_or_null<vector::ContractionOp>(
+              maybeContraction.getDefiningOp());
+      if (!contractionOp)
+        return vector::ContractionOp();
+      if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
+              contractionOp.acc().getDefiningOp())) {
+        if (maybeZero.value() ==
+            rewriter.getZeroAttr(contractionOp.acc().getType())) {
+          // Clone the contraction with its zero accumulator remapped to the
+          // other add operand, then use the clone's result in place of the
+          // add.
+          BlockAndValueMapping bvm;
+          bvm.map(contractionOp.acc(), otherOperand);
+          auto newContraction =
+              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
+          rewriter.replaceOp(addOp, newContraction.getResult());
+          return newContraction;
+        }
+      }
+      return vector::ContractionOp();
+    };
+
+    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
+    vector::ContractionOp contract = canonicalize(a, b);
+    contract = contract ? contract : canonicalize(b, a);
+    // Report success only when the IR was actually rewritten: a pattern that
+    // returns success() without modifying the IR makes the greedy driver
+    // assume progress and keep re-applying it until the iteration limit.
+    return contract ? success() : failure();
+  }
+};
+
+void ContractionOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  // Fold zero-accumulator contractions into both integer (addi) and
+  // floating-point (addf) additions that consume them.
+  results.insert<CanonicalizeContractAdd<AddIOp>>(context);
+  results.insert<CanonicalizeContractAdd<AddFOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9d810e17bcb5..d665a2b49d5e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,3 +710,49 @@ func @dead_load(%base: memref<?xf32>, %indices: vector<16xi32>,
     memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return
 }
+
+// -----
+
+#contraction_accesses0 = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// An add whose first operand is a contraction with a zero-constant accumulator
+// folds into that contraction, for both addf (f32) and addi (i8).
+// CHECK-LABEL: func @contractions
+//  CHECK-SAME:   %[[A:[0-9a-zA-Z]+]]: vector<2x3xf32>
+//  CHECK-SAME:   %[[B:[0-9a-zA-Z]+]]: vector<3x4xf32>
+//  CHECK-SAME:   %[[C:[0-9a-zA-Z]+]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[A_I8:[0-9a-zA-Z]+]]: vector<2x3xi8>
+//  CHECK-SAME:   %[[B_I8:[0-9a-zA-Z]+]]: vector<3x4xi8>
+//  CHECK-SAME:   %[[C_I8:[0-9a-zA-Z]+]]: vector<2x4xi8>
+func @contractions(%a: vector<2x3xf32>, %b: vector<3x4xf32>, %c: vector<2x4xf32>,
+                   %a_i8: vector<2x3xi8>, %b_i8: vector<3x4xi8>, %c_i8: vector<2x4xi8>)
+  -> (vector<2x4xf32>, vector<2x4xi8>)
+{
+  // CHECK-NOT: constant
+  %vf_0 = constant dense <0.0>: vector<2x4xf32>
+  // CHECK-NOT: addf
+  //     CHECK: %[[D:.*]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]]
+  %0 = vector.contract #contraction_trait0 %a, %b, %vf_0:
+    vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
+  // CHECK-NOT: addf
+  %1 = addf %0, %c: vector<2x4xf32>
+
+  // CHECK-NOT: constant
+  %vi8_0 = constant dense <0>: vector<2x4xi8>
+  // CHECK-NOT: addi
+  //     CHECK: %[[D_I8:.*]] = vector.contract {{.*}} %[[A_I8]], %[[B_I8]], %[[C_I8]]
+  %i8_0 = vector.contract #contraction_trait0 %a_i8, %b_i8, %vi8_0:
+    vector<2x3xi8>, vector<3x4xi8> into vector<2x4xi8>
+  // CHECK-NOT: addi
+  %i8_1 = addi %i8_0, %c_i8: vector<2x4xi8>
+
+  // CHECK: return %[[D]], %[[D_I8]]
+  return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
+}
+
+// The fold also applies when the contraction is the second operand of the add.
+// CHECK-LABEL: func @contract_add_commuted
+//  CHECK-SAME:   %[[A:[0-9a-zA-Z]+]]: vector<2x3xf32>
+//  CHECK-SAME:   %[[B:[0-9a-zA-Z]+]]: vector<3x4xf32>
+//  CHECK-SAME:   %[[C:[0-9a-zA-Z]+]]: vector<2x4xf32>
+func @contract_add_commuted(%a: vector<2x3xf32>, %b: vector<3x4xf32>,
+                            %c: vector<2x4xf32>) -> vector<2x4xf32> {
+  // CHECK-NOT: constant
+  %vf_0 = constant dense <0.0>: vector<2x4xf32>
+  // CHECK-NOT: addf
+  //     CHECK: %[[D:.*]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]]
+  %0 = vector.contract #contraction_trait0 %a, %b, %vf_0:
+    vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
+  // CHECK-NOT: addf
+  %1 = addf %c, %0: vector<2x4xf32>
+  // CHECK: return %[[D]]
+  return %1: vector<2x4xf32>
+}
+
+// A contraction whose accumulator is not a zero constant must be left alone,
+// together with the add that consumes it.
+// CHECK-LABEL: func @contract_add_nonzero_acc
+//       CHECK:   vector.contract
+//       CHECK:   addf
+//       CHECK:   return
+func @contract_add_nonzero_acc(%a: vector<2x3xf32>, %b: vector<3x4xf32>,
+                               %c: vector<2x4xf32>) -> vector<2x4xf32> {
+  %vf_1 = constant dense <1.0>: vector<2x4xf32>
+  %0 = vector.contract #contraction_trait0 %a, %b, %vf_1:
+    vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
+  %1 = addf %0, %c: vector<2x4xf32>
+  return %1: vector<2x4xf32>
+}
+


        


More information about the Mlir-commits mailing list