[Mlir-commits] [mlir] fb7ef63 - [mlir][vector][nvgpu] Move MMA contraction preparation to VectorUtils
Jakub Kuderski
llvmlistbot at llvm.org
Thu Mar 9 11:57:23 PST 2023
Author: Jakub Kuderski
Date: 2023-03-09T14:56:21-05:00
New Revision: fb7ef637a84652dbd3d973a1ba7db9470181b5aa
URL: https://github.com/llvm/llvm-project/commit/fb7ef637a84652dbd3d973a1ba7db9470181b5aa
DIFF: https://github.com/llvm/llvm-project/commit/fb7ef637a84652dbd3d973a1ba7db9470181b5aa.diff
LOG: [mlir][vector][nvgpu] Move MMA contraction preparation to VectorUtils
This pattern is not specific to nvgpu; I intend to use it in SPIR-V codegen. `VectorTransforms` seems like a more generally useful place.
In addition:
- Fix a bug in the second condition (the dimensions were swapped for RHS).
- Add tests.
- Add support for externally provided filter functions, similar to other vector transforms.
- Prefer to transpose before zero/sign-extending inputs.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D145638
Added:
mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir
Modified:
mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
index 5880b09161c7c..003a160985ee1 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -93,18 +93,6 @@ FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
const LdMatrixParams &params);
-/// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
-/// converted to `nvgpu.mma.sync`. This specific form is meant to indicate that
-/// the vector operands are organized such that the reduction dimension is
-/// contiguous.
-struct PrepareContractToGPUMMASync
- : public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-};
-
} // namespace nvgpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 775bfbf0241bf..1d572435cc2cd 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
namespace mlir {
class RewritePatternSet;
@@ -147,6 +148,27 @@ void populateVectorContractLoweringPatterns(
VectorTransformsOptions options = VectorTransformsOptions(),
PatternBenefit benefit = 1);
+/// Canonicalization of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a contraction with MMT semantics (matrix matrix multiplication
+/// with the RHS transposed). This specific form is meant to have the vector
+/// operands organized such that the reduction dimension is contiguous.
+/// Example:
+/// ```
+/// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+///                                   affine_map<(m, n, k) -> (n, k)>,
+///                                   affine_map<(m, n, k) -> (m, n)>],
+///                  iterator_types = ["parallel", "parallel", "reduction"],
+///                  kind = #vector.kind<add>} %a, %b, %c : ...
+/// ```
+///
+/// The `constraint` predicate is used to decide which `vector.contract` ops
+/// to rewrite: ops for which it returns failure are filtered out.
+void populateVectorContractCanonicalizeMatmulToMMT(
+    RewritePatternSet &patterns,
+    std::function<LogicalResult(vector::ContractionOp)> constraint =
+        [](vector::ContractionOp) { return success(); },
+    PatternBenefit benefit = 1);
+
/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index d9533b4e16b44..cc9813648bcc1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -1173,9 +1174,8 @@ void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
patterns.getContext());
return;
}
- patterns
- .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
- patterns.getContext());
+ vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
+ patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}
LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 44f9b6d4ea012..7525f9f57bc5f 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -272,60 +272,3 @@ nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
return failure();
}
-
-LogicalResult nvgpu::PrepareContractToGPUMMASync::matchAndRewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
- Location loc = op.getLoc();
- Value lhs = op.getLhs();
- Value rhs = op.getRhs();
- Value res = op.getAcc();
-
- // Set up the parallel/reduction structure in right form.
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
- AffineExpr m;
- AffineExpr n;
- AffineExpr k;
- bindDims(rewriter.getContext(), m, n, k);
- static constexpr std::array<int64_t, 2> perm = {1, 0};
- auto iteratorTypes = op.getIteratorTypes().getValue();
- SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
- if (iteratorTypes.size() != 3)
- return failure();
- if (!(vector::isParallelIterator(iteratorTypes[0]) &&
- vector::isParallelIterator(iteratorTypes[1]) &&
- vector::isReductionIterator(iteratorTypes[2])))
- return failure();
-
- // The canonical form is "TNT" = A row-major, B col-major, C row-major.
- const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
- if (maps == canonicalForm) {
- return failure();
- }
- if (maps == infer({{m, k}, {k, n}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
- std::swap(rhs, lhs);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
- std::swap(rhs, lhs);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
- std::swap(lhs, rhs);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
- std::swap(lhs, rhs);
- } else {
- return failure();
- }
- rewriter.replaceOpWithNewOp<vector::ContractionOp>(
- op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
- op.getIteratorTypes());
- return success();
-}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 0844fda09328d..9e9e999edf048 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include <functional>
#include <optional>
#include <type_traits>
@@ -24,6 +25,8 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -31,6 +34,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
@@ -3053,6 +3057,104 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
}
};
+/// Canonicalization of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a contraction suitable for MMT (matrix matrix multiplication
+/// with the RHS transposed) lowering. Ops failing `filter` are not rewritten.
+struct CanonicalizeContractMatmulToMMT final
+    : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
+                                  FilterConstraintType constraint)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Remove native masks from contraction op?
+    if (!op.getMasks().empty())
+      return failure();
+    // An unset filter (e.g. via the inherited constructors) accepts every op.
+    if (filter && failed(filter(op)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value res = op.getAcc();
+
+    // Set up the parallel/reduction structure in the right form.
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr m, n, k;
+    bindDims(rewriter.getContext(), m, n, k);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto iteratorTypes = op.getIteratorTypes().getValue();
+    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+    if (iteratorTypes.size() != 3 ||
+        !vector::isParallelIterator(iteratorTypes[0]) ||
+        !vector::isParallelIterator(iteratorTypes[1]) ||
+        !vector::isReductionIterator(iteratorTypes[2]))
+      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
+
+    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+    if (maps == canonicalForm)
+      return rewriter.notifyMatchFailure(op, "already in the canonical form");
+
+    // Create a vector transpose, emitting any zero/sign-extend after it. The
+    // extension type uses the transposed shape so non-square operands verify.
+    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
+      auto transposeThenExtend = [&](auto ext) -> Value {
+        Value trans =
+            rewriter.create<vector::TransposeOp>(loc, ext.getIn(), perm);
+        VectorType extTy = trans.getType().cast<VectorType>().cloneWith(
+            std::nullopt, mat.getType().cast<VectorType>().getElementType());
+        return rewriter.create<decltype(ext)>(loc, extTy, trans);
+      };
+      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>())
+        return transposeThenExtend(sext);
+      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>())
+        return transposeThenExtend(zext);
+      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
+    };
+
+    // Operands are swapped when the result map is the transposed (n, m) form.
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      rhs = createTranspose(rhs);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      rhs = createTranspose(rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = createTranspose(rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = createTranspose(rhs);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      std::swap(lhs, rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else {
+      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
+    }
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+        op.getIteratorTypes());
+    return success();
+  }
+
+private:
+  FilterConstraintType filter;
+};
+
} // namespace
void mlir::vector::populateVectorMaskMaterializationPatterns(
@@ -3104,6 +3206,14 @@ void mlir::vector::populateVectorContractLoweringPatterns(
options, patterns.getContext(), benefit);
}
+void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
+    RewritePatternSet &patterns,
+    std::function<LogicalResult(vector::ContractionOp)> constraint,
+    PatternBenefit benefit) {
+  patterns.add<CanonicalizeContractMatmulToMMT>(
+      patterns.getContext(), benefit, std::move(constraint));
+}
+
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir
new file mode 100644
index 0000000000000..d0be1230b46b3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir
@@ -0,0 +1,198 @@
+// RUN: mlir-opt %s -test-vector-contraction-prepare-for-mmt-lowering | FileCheck %s
+
+// CHECK-LABEL: func.func @not_matmul
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xf32>, [[ARG1:%.+]]: vector<4xf32>, [[ARG2:%.+]]: f32)
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: return
+func.func @not_matmul(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
+ %0 = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 :
+ vector<4xf32>, vector<4xf32> into f32
+ return %0 : f32
+}
+
+// This contraction is already in the canonical form.
+// CHECK-LABEL: func.func @matmul_mk_nk_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[LHS:%.+]] = arith.extsi [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT: [[RHS:%.+]] = arith.extsi [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs = arith.extsi %arg0: vector<4x4xi8> to vector<4x4xi32>
+ %rhs = arith.extsi %arg1: vector<4x4xi8> to vector<4x4xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Check that non-square shapes are also handled.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x16xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x16xi32>, [[ARG1:%.+]]: vector<16x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<16x4xi32> to vector<4x16xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x16xi32(%arg0: vector<4x16xi32>, %arg1: vector<16x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extui_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[LHS:%.+]] = arith.extui [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT: [[RHS:%.+]] = arith.extui [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extui_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs = arith.extui %arg0: vector<4x4xi8> to vector<4x4xi32>
+ %rhs = arith.extui %arg1: vector<4x4xi8> to vector<4x4xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_nk_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[TRANS]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHST:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG: [[LHS:%.+]] = arith.extsi [[LHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-DAG: [[RHST:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG: [[RHS:%.+]] = arith.extui [[RHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs = arith.extsi %arg0 : vector<4x4xi8> to vector<4x4xi32>
+ %rhs = arith.extui %arg1 : vector<4x4xi8> to vector<4x4xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_nk_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[LHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_nk_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[LHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 68edea5efca09..93736dade444e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -11,6 +11,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -199,6 +200,33 @@ struct TestVectorContractionLowering
}
};
+/// Test pass exercising the matmul-to-MMT `vector.contract` canonicalization.
+struct TestVectorContractionPrepareForMMTLowering
+    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorContractionPrepareForMMTLowering)
+  StringRef getArgument() const final {
+    return "test-vector-contraction-prepare-for-mmt-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test vector.contraction matmul canonicalization for MMT lowering.";
+  }
+  TestVectorContractionPrepareForMMTLowering() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, arith::ArithDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
struct TestVectorTransposeLowering
: public PassWrapper<TestVectorTransposeLowering,
OperationPass<func::FuncOp>> {
@@ -892,6 +920,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorContractionLowering>();
+ PassRegistration<TestVectorContractionPrepareForMMTLowering>();
+
PassRegistration<TestVectorTransposeLowering>();
PassRegistration<TestVectorUnrollingPatterns>();
More information about the Mlir-commits
mailing list