[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to sequence of FMAs (PR #163382)
Arun Thangamani
llvmlistbot at llvm.org
Thu Oct 23 23:08:43 PDT 2025
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/163382
From ce42a09598082c2ca7562bbf80312f5b9c3d60f8 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 14 Oct 2025 05:29:32 -0700
Subject: [PATCH 1/4] lower vector.contract to sequence of FMAs
---
.../mlir/Dialect/X86Vector/Transforms.h | 9 +
.../X86Vector/Transforms/CMakeLists.txt | 1 +
.../X86Vector/Transforms/NanoKernels.cpp | 483 ++++++++++++++++++
3 files changed, 493 insertions(+)
create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
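
For context, the rewrite below takes a register-tiled (batch-reduce) matmul
whose inner vector.contract has K = 1 and replaces it with broadcasts of A
elements, wide loads of B rows, and one vector.fma per (M row, N chunk)
pair, carrying the C tile as scf.for iter_args. A hand-written sketch of
the intended output for a degenerate tile M = 1, N = 32, vectorSize = 16
(names invented, index constants elided, subview cloning folded into direct
loads; this is not compiler output):

  %acc0 = vector.load %C[%c0, %c0]  : memref<1x32xf32>, vector<16xf32>
  %acc1 = vector.load %C[%c0, %c16] : memref<1x32xf32>, vector<16xf32>
  %r:2 = scf.for %k = %c0 to %c32 step %c1
      iter_args(%a0 = %acc0, %a1 = %acc1)
      -> (vector<16xf32>, vector<16xf32>) {
    // One scalar of A per M row, broadcast across the lane width.
    %lhs  = vector.load %A[%c0, %k] : memref<1x32xf32>, vector<1xf32>
    %bc   = vector.broadcast %lhs : vector<1xf32> to vector<16xf32>
    // One vectorSize-wide chunk of the current B row per N chunk.
    %rhs0 = vector.load %B[%k, %c0]  : memref<32x32xf32>, vector<16xf32>
    %rhs1 = vector.load %B[%k, %c16] : memref<32x32xf32>, vector<16xf32>
    %f0 = vector.fma %bc, %rhs0, %a0 : vector<16xf32>
    %f1 = vector.fma %bc, %rhs1, %a1 : vector<16xf32>
    scf.yield %f0, %f1 : vector<16xf32>, vector<16xf32>
  }

After the loop, the accumulators are packed back into the original 2-D
accumulator type (see the packing note after this patch).
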
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..cde890038f20a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
#include "mlir/IR/Value.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class ImplicitLocOpBuilder;
@@ -79,6 +83,11 @@ struct MaskHelper {
}
};
+//===----------------------------------------------------------------------===//
+// Nano-kernels
+LogicalResult nanoKernels(RewriterBase &rewriter,
+ vector::ContractionOp contractOp, int64_t vectorSize);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..da377763331f2 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
+ NanoKernels.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
new file mode 100644
index 0000000000000..bc03b567e06ff
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -0,0 +1,483 @@
+//===- NanoKernels.cpp - Lower matmul to Nanokernels ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements matmul rewrites as nanokernels tailored to the target
+// machine, for FP32 and (TODO) BF16 types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+static FailureOr<SmallVector<scf::ForOp>>
+getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) {
+ SmallVector<scf::ForOp> list;
+ Operation *current = contractOp;
+ // Walks the register-tiled loop structure of a batch-reduce matmul
+ // (M->N->Batch-reduce->K).
+ for (int i = 0; i < dimCount; i++) {
+ Operation *parent = current->getParentOfType<scf::ForOp>();
+ if (!parent)
+ return failure();
+ list.push_back(dyn_cast<scf::ForOp>(parent));
+ current = parent;
+ }
+ return list;
+}
+
+static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
+ SmallVector<memref::SubViewOp> subviews,
+ int64_t dimCount) {
+ auto subviewOpLhsOffsets = subviews[0].getOffsets();
+ auto subviewOpRhsOffsets = subviews[1].getOffsets();
+ auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+ if (dimCount == 4) {
+ Value ivK = loops[0].getInductionVar();
+ if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+ return failure();
+
+ Value ivReduction = loops[1].getInductionVar();
+ if (ivReduction != subviewOpLhsOffsets[0] ||
+ ivReduction != subviewOpRhsOffsets[0])
+ return failure();
+
+ Value ivN = loops[2].getInductionVar();
+ if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+ return failure();
+
+ Value ivM = loops[3].getInductionVar();
+ if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+ return failure();
+ }
+
+ if (dimCount == 3) {
+ Value ivK = loops[0].getInductionVar();
+ if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
+ return failure();
+
+ Value ivN = loops[1].getInductionVar();
+ if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1])
+ return failure();
+
+ Value ivM = loops[2].getInductionVar();
+ if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0])
+ return failure();
+ }
+
+ return success();
+}
+
+static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
+ Type elementType, int64_t M, int64_t N,
+ int64_t vectorSize, Value subviewOpAcc) {
+
+ SmallVector<Value> loopItrArgs;
+ int64_t outerBound = M;
+ int64_t innerBound = N;
+
+ int64_t outerStep = 1;
+ int64_t innerStep = vectorSize;
+
+ if ((N / vectorSize) > M) {
+ outerBound = N;
+ innerBound = M;
+
+ outerStep = vectorSize;
+ innerStep = 1;
+ }
+
+ for (int i = 0; i < outerBound; i = i + outerStep) {
+ for (int j = 0; j < innerBound; j = j + innerStep) {
+ Value indexOp_A = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, j);
+
+ if ((N / vectorSize) > M) {
+ indexOp_A = indexOp_B;
+ indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ }
+
+ auto valueCRow = rewriter.create<vector::LoadOp>(
+ loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
+ ValueRange{indexOp_A, indexOp_B});
+ loopItrArgs.push_back(valueCRow);
+ }
+ }
+
+ return loopItrArgs;
+}
+
+SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
+ Type elementType, int64_t vectorSize,
+ int64_t vnni, int64_t M, int64_t N,
+ ValueRange acc, Value matA, Value matB,
+ int64_t dimCount) {
+
+ SmallVector<Value> accVector;
+ SmallVector<Value> matLoad;
+ Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+ int64_t outerBound = M;
+ int64_t outerStep = 1;
+
+ int64_t innerBound = N;
+ int64_t innerStep = vectorSize;
+
+ Value outerMatrix = matA;
+ Value innerMatrix = matB;
+
+ int64_t outerVectSize = vnni;
+ int64_t innerVectSize = vectorSize;
+
+ int64_t fmaBound = M;
+
+ if ((N / vectorSize) < M) {
+ outerBound = N;
+ innerBound = M;
+
+ outerStep = vectorSize;
+ innerStep = 1;
+
+ outerMatrix = matB;
+ innerMatrix = matA;
+
+ outerVectSize = vectorSize;
+ innerVectSize = vnni;
+
+ fmaBound = N / vectorSize;
+ }
+
+ for (int i = 0; i < outerBound; i = i + outerStep) {
+ Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value valueRow;
+
+ if ((N / vectorSize) > M) {
+
+ SmallVector<Value> index = {c0, indexOp_i, c0};
+ if (dimCount == 3) {
+ index.erase(index.begin());
+ }
+
+ Value row = rewriter.create<vector::LoadOp>(
+ loc, VectorType::get(outerVectSize, elementType), outerMatrix, index);
+ valueRow = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType::get(vectorSize, rewriter.getF32Type()), row);
+ } else {
+
+ SmallVector<Value> index = {c0, c0, indexOp_i};
+ if (dimCount == 3) {
+ index.erase(index.begin());
+ }
+
+ valueRow = rewriter.create<vector::LoadOp>(
+ loc, VectorType::get(outerVectSize, elementType), outerMatrix, index);
+ }
+
+ matLoad.push_back(valueRow);
+ }
+
+ for (int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+ Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(loc, j);
+ Value valueRow;
+
+ if ((N / vectorSize) < M) {
+ SmallVector<Value> index = {c0, indexOp_j, c0};
+ if (dimCount == 3) {
+ index.erase(index.begin());
+ }
+ Value row = rewriter.create<vector::LoadOp>(
+ loc, VectorType::get(innerVectSize, elementType), innerMatrix,
+ ValueRange(index));
+ valueRow = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType::get(vectorSize, rewriter.getF32Type()), row);
+ } else {
+
+ SmallVector<Value> index = {c0, c0, indexOp_j};
+ if (dimCount == 3) {
+ index.erase(index.begin());
+ }
+
+ valueRow = rewriter.create<vector::LoadOp>(
+ loc, VectorType::get(innerVectSize, elementType), innerMatrix, index);
+ }
+
+ for (int i = 0; i < fmaBound; i = i + 1) {
+ auto fmaOdd =
+ rewriter.create<vector::FMAOp>(loc, matLoad[i], valueRow, acc[k]);
+ k++;
+ accVector.push_back(fmaOdd);
+ }
+ }
+
+ return accVector;
+}
+
+Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
+ SmallVector<Value> FMAs, Value accVec, int64_t vecSize,
+ int64_t M, int64_t N) {
+
+ auto strides = rewriter.getI64ArrayAttr({1});
+ if ((N / vecSize) > M) {
+ for (int j = 0, k = 0; j < (N / vecSize); j++) {
+ for (int i = 0; i < M; i++) {
+ int64_t off = (j * vecSize) + (i * N);
+ auto offsets = rewriter.getI64ArrayAttr({off});
+ accVec = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, vecType, FMAs[k], accVec, offsets, strides);
+ k++;
+ }
+ }
+
+ } else {
+ for (int i = 0, k = 0; i < M * N; i = i + vecSize) {
+ auto offsets = rewriter.getI64ArrayAttr({i});
+ accVec = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, vecType, FMAs[k], accVec, offsets, strides);
+ k++;
+ }
+ }
+ return accVec;
+}
+
+scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+ vector::TransferReadOp vectorReadOpLhs,
+ vector::TransferReadOp vectorReadOpRhs,
+ Value ivNewReductionForOp, Type elementType,
+ int64_t vectorSize, int64_t vnni, int64_t M, int64_t N,
+ ValueRange iterArgsNewReductionForOp, int64_t dimCount) {
+ auto newKForOp = rewriter.create<scf::ForOp>(
+ kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+ kForOp.getStep(), iterArgsNewReductionForOp,
+ [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+ Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+ IRMapping mapping;
+ mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(1),
+ ivNewReductionForOp);
+ mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(3),
+ ivNewKForOp);
+ auto lhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+ IRMapping rhsMapping;
+ rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1),
+ ivNewReductionForOp);
+ rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(2),
+ ivNewKForOp);
+ auto rhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+ auto evenFMAs =
+ nanoKernels(rewriter, kForOp.getLoc(), elementType, vectorSize,
+ vnni, M, N, iterArgsNewKForOp, lhsClone->getResult(0),
+ rhsClone->getResult(0), dimCount);
+
+ rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, evenFMAs);
+ });
+
+ return newKForOp;
+}
+
+scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+ vector::TransferReadOp vectorReadOpLhs,
+ vector::TransferReadOp vectorReadOpRhs, Type elementType,
+ int64_t vectorSize, int64_t vnni, int64_t M, int64_t N,
+ ValueRange iterArgsNewReductionForOp, int64_t dimCount) {
+
+ auto newKForOp = rewriter.create<scf::ForOp>(
+ kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+ kForOp.getStep(), iterArgsNewReductionForOp,
+ [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+ Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+ IRMapping mapping;
+ mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(2),
+ ivNewKForOp);
+ auto lhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+ IRMapping rhsMapping;
+ rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1),
+ ivNewKForOp);
+ auto rhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+ auto evenFMAs =
+ nanoKernels(rewriter, loc, elementType, vectorSize, vnni, M, N,
+ iterArgsNewKForOp, lhsClone->getResult(0),
+ rhsClone->getResult(0), dimCount);
+
+ rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, evenFMAs);
+ });
+
+ return newKForOp;
+}
+
+LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
+ vector::ContractionOp contractOp, int64_t vectorSize) {
+ auto loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
+
+ auto dimCount = contractOp.getRhsType().getRank() + 1;
+
+ if ((dimCount != 3) && (dimCount != 4))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects batch-reduce or batch matmuls");
+
+ // Get the M, N, K, and batch-reduce loops
+ auto loops = getNestedLoop(contractOp, dimCount);
+ if (failed(loops))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid loop nest in contract pattern");
+
+ auto nestedLoops = *loops;
+ scf::ForOp kForOp = nestedLoops[0];
+ scf::ForOp reductionForOp;
+
+ vector::TransferReadOp vectorReadOpAcc;
+
+ if (dimCount == 4) {
+ reductionForOp = nestedLoops[1];
+ vectorReadOpAcc =
+ reductionForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
+ }
+
+ if (dimCount == 3) {
+ vectorReadOpAcc =
+ kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
+ }
+
+ auto vectorReadOpLhs =
+ contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
+ auto vectorReadOpRhs =
+ contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
+
+ if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs)
+ return failure();
+
+ auto subviewOpAcc =
+ vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
+ auto subviewOpLhs =
+ vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
+ auto subviewOpRhs =
+ vectorReadOpRhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
+
+ if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs)
+ return failure();
+
+ SmallVector<memref::SubViewOp> subviews;
+ subviews.push_back(subviewOpLhs);
+ subviews.push_back(subviewOpRhs);
+ subviews.push_back(subviewOpAcc);
+
+ // The M, N, K, and batch-reduce loop IVs should match the IVs
+ // used in the subviews
+ auto checkLoops = checkNestedLoop(*loops, subviews, dimCount);
+ if (failed(checkLoops))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Loops doesn't match the iv in subviews");
+
+ auto elementType =
+ (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
+
+ // TODO: Support for BF16 Type
+ if (!elementType.isF32())
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only, FP32 type is supported");
+
+ auto lhsType = dyn_cast<ShapedType>(vectorReadOpLhs.getType());
+ auto rhsType = dyn_cast<ShapedType>(vectorReadOpRhs.getType());
+
+ // Get M, N, and K dimension size
+ int64_t M = lhsType.getDimSize(lhsType.getRank() - 2);
+ int64_t N = rhsType.getDimSize(rhsType.getRank() - 1);
+ int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
+ int64_t vnni = 1;
+
+ if (K != 1)
+ return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
+
+ if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
+ return rewriter.notifyMatchFailure(contractOp,
+ "The reduction-dim should be 1");
+
+
+ if (dimCount == 4)
+ rewriter.setInsertionPoint(reductionForOp);
+
+ if (dimCount == 3)
+ rewriter.setInsertionPoint(kForOp);
+
+ // Load MxN C sub-matrix into acc vectors (e.g., <vectorSizexf32>)
+ SmallVector<Value> loopItrArgs =
+ loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+
+ // Create the batch-reduce and K-loop with acc vectors as the loop
+ // iter_args (batch-reduce matmul) + nanokernel generation
+ scf::ForOp newLoop;
+ if (dimCount == 4) {
+ newLoop = rewriter.create<scf::ForOp>(
+ reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+ reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp,
+ Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) {
+ scf::ForOp newKForOp = createLoop(
+ rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
+ ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
+ iterArgsNewReductionForOp, dimCount);
+
+ rewriterNewReductionForOp.create<scf::YieldOp>(
+ locNewReductionForOp, newKForOp.getResults());
+ });
+ }
+
+ // Create only the K-loop (batch matmul) + nanokernel generation
+ if (dimCount == 3) {
+ newLoop =
+ createLoop(rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
+ elementType, vectorSize, vnni, M, N, loopItrArgs, dimCount);
+ }
+
+
+ // Combine all acc vectors into a MxN C matrix
+ auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
+ auto zeroAttr =
+ DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
+ Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
+
+ accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N);
+
+ auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
+
+ // Replace all uses of vector.contract with the results of the nanokernels
+ if (dimCount == 4)
+ rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc);
+
+ if (dimCount == 3)
+ rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc);
+
+ return success();
+}
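
A note on the accVector packing above: in the chunk-major case, i.e. when
(N / vecSize) > M, each FMA result is written at offset j * vecSize + i * N
of a flat M*N vector, so the results land interleaved rather than
sequentially. Worked through for M = 2, N = 32, vecSize = 8 (so
N / vecSize = 4 > M), the eight results map to offsets

  FMAs[0] -> 0    FMAs[1] -> 32
  FMAs[2] -> 8    FMAs[3] -> 40
  FMAs[4] -> 16   FMAs[5] -> 48
  FMAs[6] -> 24   FMAs[7] -> 56

and the flat vector<64xf32> is then vector.shape_cast back to the
vector<2x32xf32> accumulator type. In the other case the offsets are simply
0, vecSize, 2 * vecSize, ... in order.
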
From 176cb06fed4c7358603dba8453ab3c1d79910dbd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 22 Oct 2025 04:29:22 -0700
Subject: [PATCH 2/4] initial code to wrap as a Transform dialect
---
.../mlir/Dialect/X86Vector/CMakeLists.txt | 2 +
.../X86Vector/TransformOps/CMakeLists.txt | 4 ++
.../TransformOps/X86VectorTransformOps.td | 38 +++++++++++++++++
.../mlir/Dialect/X86Vector/Transforms.h | 4 +-
mlir/lib/Dialect/X86Vector/CMakeLists.txt | 1 +
.../X86Vector/TransformOps/CMakeLists.txt | 20 +++++++++
.../TransformOps/X86VectorTransformOps.cpp | 42 +++++++++++++++++++
.../X86Vector/Transforms/NanoKernels.cpp | 21 +++++++++-
8 files changed, 128 insertions(+), 4 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
add_mlir_interface(X86VectorInterfaces)
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..23f0eebaebe34
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,38 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.vector_contract_nanokernel_lowering",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector.contract operations can be lowered to target
+ specific nanokernels.
+ }];
+
+ //let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+
+ //let assemblyFormat = [{
+ //(`vector_size` `=` $vector_size^)? attr-dict
+ //}];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index cde890038f20a..6ddb4e542ba54 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -85,8 +85,8 @@ struct MaskHelper {
//===----------------------------------------------------------------------===//
// Nano-kernels
-LogicalResult nanoKernels(RewriterBase &rewriter,
- vector::ContractionOp contractOp, int64_t vectorSize);
+void populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns); //RewriterBase &rewriter,
+ //vector::ContractionOp contractOp, int64_t vectorSize);
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8814547620e58
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+ X86VectorTransformOps.cpp
+
+ DEPENDS
+ MLIRX86VectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRVectorToLLVM
+ MLIRVectorTransforms
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRVectorDialect
+ MLIRVectorToSCF
+ MLIRX86VectorTransforms
+ )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..1570972db1855
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,42 @@
+//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+
+
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorTransferLoweringPatterns(patterns);//,
+ //getVectorSize());
+}
+
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<X86VectorTransformDialectExtension>();
+}
+
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
index bc03b567e06ff..5aeba6cad0445 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -21,6 +21,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
@@ -331,9 +335,16 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
return newKForOp;
}
-LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
- vector::ContractionOp contractOp, int64_t vectorSize) {
+//LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
+ // vector::ContractionOp contractOp)//, int64_t vectorSize) {
+
+struct VectorContractNanokernelLowering final : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
auto loc = contractOp.getLoc();
+ int64_t vectorSize = 16;
if (contractOp.getKind() != vector::CombiningKind::ADD) {
return rewriter.notifyMatchFailure(contractOp,
@@ -481,3 +492,9 @@ LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
return success();
}
+};
+
+
+void x86vector::populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns) {
+ patterns.add<VectorContractNanokernelLowering>(patterns.getContext());
+}
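
At this stage of the series the transform op takes no attributes (its
assemblyFormat is just attr-dict) and populatePatterns still delegates to
vector::populateVectorTransferLoweringPatterns as a placeholder, so a
minimal schedule exercising it would look like the sketch below (assuming
the extension is registered; the vector_size attribute only arrives in a
later commit):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(
        %arg0: !transform.any_op {transform.readonly}) {
      %func = transform.structured.match ops{["func.func"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
        transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering
      } : !transform.any_op
      transform.yield
    }
  }
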
From 15f4c5d54ccf3c820097dc8f4e2dc8d2999e3fe8 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 22 Oct 2025 07:37:41 -0700
Subject: [PATCH 3/4] fix build errors in transform code
---
.../TransformOps/X86VectorTransformOps.h | 36 +++++++++++++++++++
.../X86Vector/TransformOps/CMakeLists.txt | 2 +-
.../TransformOps/X86VectorTransformOps.cpp | 30 +++++++++++++---
3 files changed, 63 insertions(+), 5 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..abb5da75e5bfd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,36 @@
+//===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace x86vector {
+} // namespace vector
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
index 8814547620e58..5f85f7af60d01 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -14,7 +14,7 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRTransformDialectUtils
- MLIRVectorDialect
+ MLIRX86VectorDialect
MLIRVectorToSCF
MLIRX86VectorTransforms
)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 1570972db1855..5702106558409 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -14,11 +14,13 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
-#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
-#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
using namespace mlir;
using namespace mlir::x86vector;
using namespace mlir::transform;
@@ -27,10 +29,31 @@ using namespace mlir::transform;
void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorTransferLoweringPatterns(patterns);//,
+ x86vector::populateVectorContractNanokernelLoweringPatterns(patterns);//,
//getVectorSize());
}
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ X86VectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86VectorTransformDialectExtension)
+
+ X86VectorTransformDialectExtension() {
+ declareGeneratedDialect<x86vector::X86VectorDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
@@ -39,4 +62,3 @@ void mlir::x86vector::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<X86VectorTransformDialectExtension>();
}
-
From ca4e25291d23c9140fbbc4904f816437aa2d48bb Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 23 Oct 2025 23:08:27 -0700
Subject: [PATCH 4/4] Round1: env set up + clang format fix
---
.../TransformOps/X86VectorTransformOps.h | 11 +-
.../TransformOps/X86VectorTransformOps.td | 9 +-
.../mlir/Dialect/X86Vector/Transforms.h | 4 +-
.../X86Vector/TransformOps/CMakeLists.txt | 3 -
.../TransformOps/X86VectorTransformOps.cpp | 23 +-
.../X86Vector/Transforms/NanoKernels.cpp | 363 +++++++++---------
mlir/lib/RegisterAllExtensions.cpp | 2 +
.../vector-contract-to-nanokernels.mlir | 48 +++
8 files changed, 256 insertions(+), 207 deletions(-)
create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
index abb5da75e5bfd..e1d8b8762e799 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -1,4 +1,4 @@
-//===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===//
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,13 +12,8 @@
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
-namespace mlir {
-namespace x86vector {
-} // namespace vector
-} // namespace mlir
-
//===----------------------------------------------------------------------===//
-// Vector Transform Operations
+// X86Vector Transform Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
@@ -30,7 +25,7 @@ class DialectRegistry;
namespace x86vector {
void registerTransformDialectExtension(DialectRegistry &registry);
-} // namespace vector
+} // namespace x86vector
} // namespace mlir
#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 23f0eebaebe34..9db2b36a2a8aa 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -25,13 +25,12 @@ def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
specific nanokernels.
}];
- //let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+ let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
- //let assemblyFormat = [{
- //(`vector_size` `=` $vector_size^)? attr-dict
- //}];
+ let assemblyFormat = [{
+ (`vector_size` `=` $vector_size^)? attr-dict
+ }];
- let assemblyFormat = "attr-dict";
}
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 6ddb4e542ba54..a9487adba002a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -85,8 +85,8 @@ struct MaskHelper {
//===----------------------------------------------------------------------===//
// Nano-kernels
-void populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns); //RewriterBase &rewriter,
- //vector::ContractionOp contractOp, int64_t vectorSize);
+void populateVectorContractNanokernelLoweringPatterns(
+ RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
index 5f85f7af60d01..f4c9f8a05acbc 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -9,12 +9,9 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
- MLIRVectorToLLVM
- MLIRVectorTransforms
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRTransformDialectUtils
MLIRX86VectorDialect
- MLIRVectorToSCF
MLIRX86VectorTransforms
)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 5702106558409..e003e3ad7cd08 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -1,4 +1,5 @@
-//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops --===//
+//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops
+//--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,31 +7,26 @@
//
//===----------------------------------------------------------------------===//
-
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
-
-#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"
-#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
-
using namespace mlir;
using namespace mlir::x86vector;
using namespace mlir::transform;
-
-
-void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns(
- RewritePatternSet &patterns) {
- x86vector::populateVectorContractNanokernelLoweringPatterns(patterns);//,
- //getVectorSize());
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateVectorContractNanokernelLoweringPatterns(patterns,
+ getVectorSize());
}
//===----------------------------------------------------------------------===//
@@ -42,7 +38,8 @@ class X86VectorTransformDialectExtension
: public transform::TransformDialectExtension<
X86VectorTransformDialectExtension> {
public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86VectorTransformDialectExtension)
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ X86VectorTransformDialectExtension)
X86VectorTransformDialectExtension() {
declareGeneratedDialect<x86vector::X86VectorDialect>();
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
index 5aeba6cad0445..583334333a49d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -16,30 +16,28 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
-#include "mlir/Interfaces/DestinationStyleOpInterface.h"
-
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
static FailureOr<SmallVector<scf::ForOp>>
-getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) {
+getNestedLoop(vector::ContractionOp contractOp, unsigned int dimCount) {
SmallVector<scf::ForOp> list;
Operation *current = contractOp;
// Walks the register-tiled loop structure of a batch-reduce matmul
// (M->N->Batch-reduce->K).
- for (int i = 0; i < dimCount; i++) {
+ for (unsigned int i = 0; i < dimCount; i++) {
Operation *parent = current->getParentOfType<scf::ForOp>();
if (!parent)
return failure();
@@ -51,7 +49,7 @@ getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) {
static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
SmallVector<memref::SubViewOp> subviews,
- int64_t dimCount) {
+ unsigned int dimCount) {
auto subviewOpLhsOffsets = subviews[0].getOffsets();
auto subviewOpRhsOffsets = subviews[1].getOffsets();
auto subviewOpAccOffsets = subviews[2].getOffsets();
@@ -93,15 +91,16 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
}
static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
- Type elementType, int64_t M, int64_t N,
- int64_t vectorSize, Value subviewOpAcc) {
+ Type elementType, unsigned int M,
+ unsigned int N, unsigned int vectorSize,
+ Value subviewOpAcc) {
SmallVector<Value> loopItrArgs;
- int64_t outerBound = M;
- int64_t innerBound = N;
+ unsigned int outerBound = M;
+ unsigned int innerBound = N;
- int64_t outerStep = 1;
- int64_t innerStep = vectorSize;
+ unsigned int outerStep = 1;
+ unsigned int innerStep = vectorSize;
if ((N / vectorSize) > M) {
outerBound = N;
@@ -111,8 +110,8 @@ static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
innerStep = 1;
}
- for (int i = 0; i < outerBound; i = i + outerStep) {
- for (int j = 0; j < innerBound; j = j + innerStep) {
+ for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+ for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
Value indexOp_A = rewriter.create<arith::ConstantIndexOp>(loc, i);
Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, j);
@@ -132,28 +131,28 @@ static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
}
SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
- Type elementType, int64_t vectorSize,
- int64_t vnni, int64_t M, int64_t N,
- ValueRange acc, Value matA, Value matB,
- int64_t dimCount) {
+ Type elementType, unsigned int vectorSize,
+ unsigned int vnni, unsigned int M,
+ unsigned int N, ValueRange acc, Value matA,
+ Value matB, unsigned int dimCount) {
SmallVector<Value> accVector;
SmallVector<Value> matLoad;
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- int64_t outerBound = M;
- int64_t outerStep = 1;
+ unsigned int outerBound = M;
+ unsigned int outerStep = 1;
- int64_t innerBound = N;
- int64_t innerStep = vectorSize;
+ unsigned int innerBound = N;
+ unsigned int innerStep = vectorSize;
Value outerMatrix = matA;
Value innerMatrix = matB;
- int64_t outerVectSize = vnni;
- int64_t innerVectSize = vectorSize;
+ unsigned int outerVectSize = vnni;
+ unsigned int innerVectSize = vectorSize;
- int64_t fmaBound = M;
+ unsigned int fmaBound = M;
if ((N / vectorSize) < M) {
outerBound = N;
@@ -171,7 +170,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
fmaBound = N / vectorSize;
}
- for (int i = 0; i < outerBound; i = i + outerStep) {
+ for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(loc, i);
Value valueRow;
@@ -200,7 +199,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
matLoad.push_back(valueRow);
}
- for (int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+ for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(loc, j);
Value valueRow;
@@ -225,7 +224,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
loc, VectorType::get(innerVectSize, elementType), innerMatrix, index);
}
- for (int i = 0; i < fmaBound; i = i + 1) {
+ for (unsigned int i = 0; i < fmaBound; i = i + 1) {
auto fmaOdd =
rewriter.create<vector::FMAOp>(loc, matLoad[i], valueRow, acc[k]);
k++;
@@ -237,14 +236,14 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
}
Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
- SmallVector<Value> FMAs, Value accVec, int64_t vecSize,
- int64_t M, int64_t N) {
+ SmallVector<Value> FMAs, Value accVec, unsigned int vecSize,
+ unsigned int M, unsigned int N) {
auto strides = rewriter.getI64ArrayAttr({1});
if ((N / vecSize) > M) {
- for (int j = 0, k = 0; j < (N / vecSize); j++) {
- for (int i = 0; i < M; i++) {
- int64_t off = (j * vecSize) + (i * N);
+ for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
+ for (unsigned int i = 0; i < M; i++) {
+ unsigned int off = (j * vecSize) + (i * N);
auto offsets = rewriter.getI64ArrayAttr({off});
accVec = rewriter.create<vector::InsertStridedSliceOp>(
loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -253,7 +252,7 @@ Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
}
} else {
- for (int i = 0, k = 0; i < M * N; i = i + vecSize) {
+ for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
auto offsets = rewriter.getI64ArrayAttr({i});
accVec = rewriter.create<vector::InsertStridedSliceOp>(
loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -267,8 +266,10 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
vector::TransferReadOp vectorReadOpLhs,
vector::TransferReadOp vectorReadOpRhs,
Value ivNewReductionForOp, Type elementType,
- int64_t vectorSize, int64_t vnni, int64_t M, int64_t N,
- ValueRange iterArgsNewReductionForOp, int64_t dimCount) {
+ unsigned int vectorSize, unsigned int vnni,
+ unsigned int M, unsigned int N,
+ ValueRange iterArgsNewReductionForOp,
+ unsigned int dimCount) {
auto newKForOp = rewriter.create<scf::ForOp>(
kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
kForOp.getStep(), iterArgsNewReductionForOp,
@@ -304,8 +305,10 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
vector::TransferReadOp vectorReadOpLhs,
vector::TransferReadOp vectorReadOpRhs, Type elementType,
- int64_t vectorSize, int64_t vnni, int64_t M, int64_t N,
- ValueRange iterArgsNewReductionForOp, int64_t dimCount) {
+ unsigned int vectorSize, unsigned int vnni,
+ unsigned int M, unsigned int N,
+ ValueRange iterArgsNewReductionForOp,
+ unsigned int dimCount) {
auto newKForOp = rewriter.create<scf::ForOp>(
kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
@@ -335,166 +338,174 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
return newKForOp;
}
-//LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
- // vector::ContractionOp contractOp)//, int64_t vectorSize) {
-
-struct VectorContractNanokernelLowering final : public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+struct VectorContractNanokernelLowering
+ : public OpRewritePattern<vector::ContractionOp> {
+ VectorContractNanokernelLowering(MLIRContext *context,
+ std::optional<unsigned> vecSize)
+ : OpRewritePattern<vector::ContractionOp>(context),
+ userVectorSize(vecSize) {}
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- auto loc = contractOp.getLoc();
- int64_t vectorSize = 16;
-
- if (contractOp.getKind() != vector::CombiningKind::ADD) {
- return rewriter.notifyMatchFailure(contractOp,
- "Expects add combining kind");
- }
-
- auto dimCount = contractOp.getRhsType().getRank() + 1;
- if ((dimCount != 3) && (dimCount != 4))
- return rewriter.notifyMatchFailure(contractOp,
- "Expects batch-reduce or batch matmuls");
+ auto loc = contractOp.getLoc();
- // Get the M, N, K, and batch-reduce loops
- auto loops = getNestedLoop(contractOp, dimCount);
- if (failed(loops))
- return rewriter.notifyMatchFailure(contractOp,
- "Invalid loop nest in contract pattern");
+ unsigned int vectorSize = 8;
- auto nestedLoops = *loops;
- scf::ForOp kForOp = nestedLoops[0];
- scf::ForOp reductionForOp;
+ if (userVectorSize)
+ vectorSize = *userVectorSize;
- vector::TransferReadOp vectorReadOpAcc;
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
- if (dimCount == 4) {
- reductionForOp = nestedLoops[1];
- vectorReadOpAcc =
- reductionForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
- }
+ auto dimCount = contractOp.getRhsType().getRank() + 1;
- if (dimCount == 3) {
- vectorReadOpAcc =
- kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
- }
+ if ((dimCount != 3) && (dimCount != 4))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Expects batch-reduce or batch matmuls");
- auto vectorReadOpLhs =
- contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
- auto vectorReadOpRhs =
- contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
+ // Get the M, N, K, and batch-reduce loops
+ auto loops = getNestedLoop(contractOp, dimCount);
+ if (failed(loops))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Invalid loop nest in contract pattern");
- if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs)
- return failure();
+ auto nestedLoops = *loops;
+ scf::ForOp kForOp = nestedLoops[0];
+ scf::ForOp reductionForOp;
- auto subviewOpAcc =
- vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
- auto subviewOpLhs =
- vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
- auto subviewOpRhs =
- vectorReadOpRhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
+ vector::TransferReadOp vectorReadOpAcc;
- if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs)
- return failure();
-
- SmallVector<memref::SubViewOp> subviews;
- subviews.push_back(subviewOpLhs);
- subviews.push_back(subviewOpRhs);
- subviews.push_back(subviewOpAcc);
+ if (dimCount == 4) {
+ reductionForOp = nestedLoops[1];
+ vectorReadOpAcc = reductionForOp.getInitArgs()[0]
+ .getDefiningOp<vector::TransferReadOp>();
+ }
- // The M, N, K, and batch-reduce loop IVs should match the IVs
- // used in the subviews
- auto checkLoops = checkNestedLoop(*loops, subviews, dimCount);
- if (failed(checkLoops))
- return rewriter.notifyMatchFailure(
- contractOp, "Loops doesn't match the iv in subviews");
+ if (dimCount == 3) {
+ vectorReadOpAcc =
+ kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
+ }
- auto elementType =
- (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
+ auto vectorReadOpLhs =
+ contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
+ auto vectorReadOpRhs =
+ contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
- // TODO: Support for BF16 Type
- if (!elementType.isF32())
- return rewriter.notifyMatchFailure(
- contractOp, "Only, FP32 type is supported");
+ if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs)
+ return failure();
- auto lhsType = dyn_cast<ShapedType>(vectorReadOpLhs.getType());
- auto rhsType = dyn_cast<ShapedType>(vectorReadOpRhs.getType());
+ auto subviewOpAcc =
+ vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
+ auto subviewOpLhs =
+ vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
+ auto subviewOpRhs =
+ vectorReadOpRhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
- // Get M, N, and K dimension size
- int64_t M = lhsType.getDimSize(lhsType.getRank() - 2);
- int64_t N = rhsType.getDimSize(rhsType.getRank() - 1);
- int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
- int64_t vnni = 1;
+ if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs)
+ return failure();
- if (K != 1)
- return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
+ SmallVector<memref::SubViewOp> subviews;
+ subviews.push_back(subviewOpLhs);
+ subviews.push_back(subviewOpRhs);
+ subviews.push_back(subviewOpAcc);
+
+ // The M, N, K, and batch-reduce loop IVs should match the IVs
+ // used in the subviews
+ auto checkLoops = checkNestedLoop(*loops, subviews, dimCount);
+ if (failed(checkLoops))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Loops doesn't match the iv in subviews");
+
+ auto elementType =
+ (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
+
+ // TODO: Support for BF16 Type
+ if (!elementType.isF32())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only, FP32 type is supported");
+
+ auto lhsType = dyn_cast<ShapedType>(vectorReadOpLhs.getType());
+ auto rhsType = dyn_cast<ShapedType>(vectorReadOpRhs.getType());
+
+ // Get M, N, and K dimension size
+ unsigned int M = lhsType.getDimSize(lhsType.getRank() - 2);
+ unsigned int N = rhsType.getDimSize(rhsType.getRank() - 1);
+ unsigned int K = lhsType.getDimSize(lhsType.getRank() - 1);
+ unsigned int vnni = 1;
+
+ if (K != 1)
+ return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
+
+ if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
+ return rewriter.notifyMatchFailure(contractOp,
+ "The reduction-dim should be 1");
+
+ if (dimCount == 4)
+ rewriter.setInsertionPoint(reductionForOp);
+
+ if (dimCount == 3)
+ rewriter.setInsertionPoint(kForOp);
+
+ // Load MxN C sub-matrix into acc vectors (e.g., <vectorSizexf32>)
+ SmallVector<Value> loopItrArgs =
+ loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+
+ // Create the batch-reduce and K-loop with acc vectors as the loop
+ // iter_args (batch-reduce matmul) + nanokernel generation
+ scf::ForOp newLoop;
+ if (dimCount == 4) {
+ newLoop = rewriter.create<scf::ForOp>(
+ reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+ reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs,
+ [&](OpBuilder &rewriterNewReductionForOp,
+ Location locNewReductionForOp, Value ivNewReductionForOp,
+ ValueRange iterArgsNewReductionForOp) {
+ scf::ForOp newKForOp = createLoop(
+ rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
+ ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
+ iterArgsNewReductionForOp, dimCount);
+
+ rewriterNewReductionForOp.create<scf::YieldOp>(
+ locNewReductionForOp, newKForOp.getResults());
+ });
+ }
- if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
- return rewriter.notifyMatchFailure(contractOp,
- "The reduction-dim should be 1");
+ // Create only the K-loop (batch matmul) + nanokernel generation
+ if (dimCount == 3) {
+ newLoop = createLoop(rewriter, loc, kForOp, vectorReadOpLhs,
+ vectorReadOpRhs, elementType, vectorSize, vnni, M, N,
+ loopItrArgs, dimCount);
+ }
+ // Combine all acc vectors into a MxN C matrix
+ auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
+ auto zeroAttr =
+ DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
+ Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
- if (dimCount == 4)
- rewriter.setInsertionPoint(reductionForOp);
+ accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec,
+ vectorSize, M, N);
- if (dimCount == 3)
- rewriter.setInsertionPoint(kForOp);
+ auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
- // Load MxN C sub-matrix into acc vectors (e.g., <vectorSizexf32>)
- SmallVector<Value> loopItrArgs =
- loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+ // Replace all uses of vector.contract with the results of the nanokernels
+ if (dimCount == 4)
+ rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc);
- // Create the batch-reduce and K-loop with acc vectors as the loop
- // iter_args (batch-reduce matmul) + nanokernel generation
- scf::ForOp newLoop;
- if (dimCount == 4) {
- newLoop = rewriter.create<scf::ForOp>(
- reductionForOp.getLoc(), reductionForOp.getLowerBound(),
- reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs,
- [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp,
- Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) {
- scf::ForOp newKForOp = createLoop(
- rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
- ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
- iterArgsNewReductionForOp, dimCount);
-
- rewriterNewReductionForOp.create<scf::YieldOp>(
- locNewReductionForOp, newKForOp.getResults());
- });
- }
+ if (dimCount == 3)
+ rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc);
- // Create only the K-loop (batch matmul) + nanokernel generation
- if (dimCount == 3) {
- newLoop =
- createLoop(rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
- elementType, vectorSize, vnni, M, N, loopItrArgs, dimCount);
+ return success();
}
-
-
- // Combine all acc vectors into a MxN C matrix
- auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
- auto zeroAttr =
- DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
- Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
-
- accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N);
-
- auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
- auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
-
- // Replace all uses of vector.contract with the results of the nanokernels
- if (dimCount == 4)
- rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc);
-
- if (dimCount == 3)
- rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc);
-
- return success();
-}
+ std::optional<unsigned> userVectorSize;
};
-
-void x86vector::populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns) {
- patterns.add<VectorContractNanokernelLowering>(patterns.getContext());
+void x86vector::populateVectorContractNanokernelLoweringPatterns(
+ RewritePatternSet &patterns, std::optional<unsigned> userVectorSize) {
+ patterns.add<VectorContractNanokernelLowering>(patterns.getContext(),
+ userVectorSize);
}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 3839172fd0b42..efcd09fc1b924 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) {
transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
+ x86vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
new file mode 100644
index 0000000000000..78ff150bb776e
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module {
+ func.func @fp32_vectorSize_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+ %0 = ub.poison : f32
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c96 = arith.constant 96 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ scf.for %arg3 = %c0 to %c4 step %c4 {
+ scf.for %arg4 = %c0 to %c96 step %c96 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [4, 96] [1, 1] : memref<4x96xf32> to memref<4x96xf32, strided<[96, 1], offset: ?>>
+ %1 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32, strided<[96, 1], offset: ?>>, vector<4x96xf32>
+ %2 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %1) -> (vector<4x96xf32>) {
+ %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x96xf32>) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>
+ %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 96] [1, 1, 1] : memref<1x32x96xf32> to memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>>
+ %4 = vector.transfer_read %subview_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32>
+ %5 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>>, vector<1x1x96xf32>
+ %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x96xf32> into vector<4x96xf32>
+ scf.yield %6 : vector<4x96xf32>
+ }
+ scf.yield %3 : vector<4x96xf32>
+ }
+ vector.transfer_write %2, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32, strided<[96, 1], offset: ?>>
+ }
+ }
+ return %arg2 : memref<4x96xf32>
+ }
+}
+
+// CHECK-LABEL: func.func @fp32_vectorSize_16(
+// CHECK-COUNT-24: vector.fma{{.*}}vector<16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16
+ } : !transform.any_op
+ transform.yield
+ }
+}
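
As a sanity check on the CHECK-COUNT above: with M = 4, N = 96, and
vector_size = 16, the kernel keeps N / 16 = 6 column chunks of the C tile
and emits M * (N / vector_size) = 4 * 6 = 24 vector.fma ops inside the
K loop, which is exactly what CHECK-COUNT-24 expects.
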