[Mlir-commits] [mlir] 02d053e - [mlir][Vector] Add a canonicalization pattern for vector.contract + add
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Feb 15 13:25:50 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-15T21:22:36Z
New Revision: 02d053ed2d2ef626c3fc747f5224fad605b46060
URL: https://github.com/llvm/llvm-project/commit/02d053ed2d2ef626c3fc747f5224fad605b46060
DIFF: https://github.com/llvm/llvm-project/commit/02d053ed2d2ef626c3fc747f5224fad605b46060.diff
LOG: [mlir][Vector] Add a canonicalization pattern for vector.contract + add
Differential Revision: https://reviews.llvm.org/D96701
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 0f80b753b2c2..a7c12231a91f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -246,6 +246,8 @@ def Vector_ContractionOp :
return CombiningKind::ADD;
}
}];
+
+ let hasCanonicalizer = 1;
}
def Vector_ReductionOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 678205d0b5d2..671cd865b1c3 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
@@ -658,6 +659,66 @@ Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
return shape;
}
+/// Canonicalize a vector.contract whose accumulator is a zero constant and
+/// whose result feeds an integer or floating-point addition, i.e. a pattern
+/// such as:
+///
+/// ```mlir
+/// %c0 = constant dense<0.0> : ...
+/// %c = vector.contract %a, %b, %c0 : ...
+/// %e = addf %c, %d : ...
+/// ```
+///
+/// into:
+///
+/// ```mlir
+/// %e = vector.contract %a, %b, %d : ...
+/// ```
+///
+/// Both operand orders of the add are matched. The add is replaced by a clone
+/// of the contraction whose accumulator is the other add operand; the original
+/// contraction is left in place (it becomes dead unless it has other uses).
+/// The pattern fails to match when neither add operand is such a contraction.
+// TODO: This should be a folding of Add into Contract in core but while they
+// live in different dialects, it is not possible without unnatural
+// dependencies.
+template <typename AddOpType>
+struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
+  using OpRewritePattern<AddOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AddOpType addOp,
+                                PatternRewriter &rewriter) const override {
+    // If `maybeContraction` is produced by a vector.contract whose accumulator
+    // is a ConstantOp equal to the zero attribute of the accumulator type,
+    // replace `addOp` with a clone of that contraction accumulating into
+    // `otherOperand` and return the clone. Otherwise leave the IR untouched
+    // and return a null op.
+    auto canonicalize = [&](Value maybeContraction,
+                            Value otherOperand) -> vector::ContractionOp {
+      vector::ContractionOp contractionOp =
+          dyn_cast_or_null<vector::ContractionOp>(
+              maybeContraction.getDefiningOp());
+      if (!contractionOp)
+        return vector::ContractionOp();
+      if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
+              contractionOp.acc().getDefiningOp())) {
+        if (maybeZero.value() ==
+            rewriter.getZeroAttr(contractionOp.acc().getType())) {
+          // Clone the contraction with its accumulator remapped to the other
+          // add operand, then let the clone stand in for the add.
+          BlockAndValueMapping bvm;
+          bvm.map(contractionOp.acc(), otherOperand);
+          auto newContraction =
+              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
+          rewriter.replaceOp(addOp, newContraction.getResult());
+          return newContraction;
+        }
+      }
+      return vector::ContractionOp();
+    };
+
+    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
+    vector::ContractionOp contract = canonicalize(a, b);
+    contract = contract ? contract : canonicalize(b, a);
+    // Only report success when the IR was actually rewritten. Returning
+    // success() unconditionally would tell the rewrite driver that every add
+    // it visits was changed, so it could never reach a fixed point.
+    return contract ? success() : failure();
+  }
+};
+
+void ContractionOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  // Register CanonicalizeContractAdd once per supported add flavor:
+  // integer (AddIOp) first, then floating-point (AddFOp).
+  results.insert<CanonicalizeContractAdd<AddIOp>>(context);
+  results.insert<CanonicalizeContractAdd<AddFOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9d810e17bcb5..d665a2b49d5e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,3 +710,49 @@ func @dead_load(%base: memref<?xf32>, %indices: vector<16xi32>,
memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return
}
+
+// -----
+
+#contraction_accesses0 = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// A zero-accumulator contraction feeding an add is rewritten to accumulate
+// directly into the other add operand. The f32 case has the contraction as the
+// first add operand and the i8 case has it as the second, so both operand
+// orders handled by the pattern are exercised.
+// CHECK-LABEL: func @contractions
+// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: vector<2x3xf32>
+// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: vector<3x4xf32>
+// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: vector<2x4xf32>
+// CHECK-SAME: %[[A_I8:[0-9a-zA-Z]+]]: vector<2x3xi8>
+// CHECK-SAME: %[[B_I8:[0-9a-zA-Z]+]]: vector<3x4xi8>
+// CHECK-SAME: %[[C_I8:[0-9a-zA-Z]+]]: vector<2x4xi8>
+func @contractions(%a: vector<2x3xf32>, %b: vector<3x4xf32>, %c: vector<2x4xf32>,
+                   %a_i8: vector<2x3xi8>, %b_i8: vector<3x4xi8>, %c_i8: vector<2x4xi8>)
+  -> (vector<2x4xf32>, vector<2x4xi8>)
+{
+  // CHECK-NOT: constant
+  %vf_0 = constant dense <0.0>: vector<2x4xf32>
+  // CHECK-NOT: addf
+  // CHECK: %[[D:.*]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]]
+  %0 = vector.contract #contraction_trait0 %a, %b, %vf_0:
+    vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
+  // CHECK-NOT: addf
+  %1 = addf %0, %c: vector<2x4xf32>
+
+  // CHECK-NOT: constant
+  %vi8_0 = constant dense <0>: vector<2x4xi8>
+  // CHECK-NOT: addi
+  // CHECK: %[[D_I8:.*]] = vector.contract {{.*}} %[[A_I8]], %[[B_I8]], %[[C_I8]]
+  %i8_0 = vector.contract #contraction_trait0 %a_i8, %b_i8, %vi8_0:
+    vector<2x3xi8>, vector<3x4xi8> into vector<2x4xi8>
+  // CHECK-NOT: addi
+  %i8_1 = addi %c_i8, %i8_0: vector<2x4xi8>
+
+  // CHECK: return %[[D]], %[[D_I8]]
+  return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
+}
+
+// A non-zero accumulator must not be folded into the add.
+// CHECK-LABEL: func @contraction_nonzero_acc
+// CHECK: vector.contract
+// CHECK: addf
+func @contraction_nonzero_acc(%a: vector<2x3xf32>, %b: vector<3x4xf32>,
+                              %c: vector<2x4xf32>) -> vector<2x4xf32> {
+  %vf_1 = constant dense<1.0> : vector<2x4xf32>
+  %0 = vector.contract #contraction_trait0 %a, %b, %vf_1 :
+    vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
+  %1 = addf %0, %c : vector<2x4xf32>
+  return %1 : vector<2x4xf32>
+}
+
More information about the Mlir-commits
mailing list