[Mlir-commits] [mlir] [mlir][x86]Convert a `linalg.generic` with BF16/Int8 accumulation to F32/Int32 accumulation (PR #190779)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 7 05:03:59 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Arun Thangamani (arun-thmn)
<details>
<summary>Changes</summary>
Rewrites a `linalg.generic` from low-precision (bf16/i8) to high-precision accumulation (f32/i32). Performs compute (mul + add) in higher precision, starting from a zero-initialized accumulator. Then adds the original output and casts (truncates) back to the original type.
This rewrite helps target machine-specific dot-product operations such as `x86.avx.dot.i8`, `x86.avx10.dot.i8`, and `x86.avx512.dot`, which accumulate in `f32`/`i32`.
---
Patch is 27.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/190779.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td (+11)
- (modified) mlir/include/mlir/Dialect/X86/Transforms.h (+8)
- (modified) mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp (+5)
- (modified) mlir/lib/Dialect/X86/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp (+221)
- (added) mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir (+290)
``````````diff
diff --git a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
index c474cfb47d003..15aba2a21c4a2 100644
--- a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
+++ b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
@@ -82,5 +82,16 @@ def ApplyVectorContractToAMXDotProductPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyConvertLinalgGenericTo32BitAccumulationPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that rewrite a `linalg.generic` performing a
+    low-precision (bf16/i8) multiply-accumulate so that the accumulation is
+    done in high precision (f32/i32). The original output is then added to
+    the wide result, which is finally truncated back to the original element
+    type.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
#endif // X86_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86/Transforms.h b/mlir/include/mlir/Dialect/X86/Transforms.h
index 6ebba5e94ec7c..749e82a4b4e36 100644
--- a/mlir/include/mlir/Dialect/X86/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86/Transforms.h
@@ -110,6 +110,14 @@ void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
// Int8).
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns);
+/// Adds to `patterns` a rewrite that moves a low-precision (bf16/i8)
+/// `linalg.generic` accumulation into high precision (f32/i32): the mul + add
+/// is computed in the wide type on a zero-initialized accumulator, after which
+/// the original output is added and the sum is truncated back to the original
+/// element type.
+void populateConvertLinalgGenericTo32BitAccumulationPatterns(
+    RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
index 390b21e12b0ed..fbddf19be8848 100644
--- a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
+++ b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
@@ -52,6 +52,11 @@ void mlir::transform::ApplyVectorContractToAMXDotProductPatternsOp::
x86::populateVectorContractToAMXDotProductPatterns(patterns);
}
+void mlir::transform::ApplyConvertLinalgGenericTo32BitAccumulationPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  // The transform op is a thin descriptor; the pattern itself lives in the
+  // X86 transforms library.
+  mlir::x86::populateConvertLinalgGenericTo32BitAccumulationPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
index 9c3695536cda9..a4bcb17f9fdb0 100644
--- a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRX86Transforms
SinkVectorProducerOps.cpp
ShuffleVectorFMAOps.cpp
VectorContractToAMXDotProduct.cpp
+ ConvertLinalgGenericTo32BitAccumulation.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp b/mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp
new file mode 100644
index 0000000000000..2576864c58f64
--- /dev/null
+++ b/mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp
@@ -0,0 +1,221 @@
+//===- ConvertLinalgGenericTo32BitAccumulation.cpp------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/X86/Transforms.h"
+#include "mlir/Dialect/X86/X86Dialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if `body` (the region of a linalg.generic with two inputs and
+/// one init) computes exactly `out + in0 * in1` with `MulOpTy` / `AddOpTy`.
+/// Operands of both the multiply and the add may appear in either order.
+template <typename MulOpTy, typename AddOpTy>
+bool isMulAddAccumulation(Block &body) {
+  // Expect exactly three ops: mul, add, yield.
+  if (body.getNumArguments() != 3 || body.getOperations().size() != 3)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+  if (!yieldOp || yieldOp->getNumOperands() != 1)
+    return false;
+
+  auto addOp = yieldOp->getOperand(0).getDefiningOp<AddOpTy>();
+  if (!addOp)
+    return false;
+
+  // One addend must be the accumulator block argument, the other the product.
+  // Requiring block-argument operands also guarantees both ops live in `body`.
+  Value out = body.getArgument(2);
+  Value product;
+  if (addOp.getLhs() == out)
+    product = addOp.getRhs();
+  else if (addOp.getRhs() == out)
+    product = addOp.getLhs();
+  else
+    return false;
+
+  auto mulOp = product.getDefiningOp<MulOpTy>();
+  if (!mulOp)
+    return false;
+
+  Value in0 = body.getArgument(0);
+  Value in1 = body.getArgument(1);
+  return (mulOp.getLhs() == in0 && mulOp.getRhs() == in1) ||
+         (mulOp.getLhs() == in1 && mulOp.getRhs() == in0);
+}
+
+// Rewrites a linalg.generic that performs a low-precision (bf16/i8)
+// multiply-accumulate into one that accumulates in high precision (f32/i32).
+//
+// The pattern only applies when:
+//   * the op has pure tensor semantics, two inputs and one init;
+//   * the output element type is bf16 or i8 and both inputs share it;
+//   * the region computes exactly `out + in0 * in1` (operands of either
+//     binary op in any order) using arith.mulf/arith.addf for bf16 or
+//     arith.muli/arith.addi for i8.
+// The region is regenerated from scratch, so any other body is rejected
+// instead of being silently replaced by a multiply-accumulate.
+//
+// The mul + add runs in the wide type on a zero-filled accumulator (operands
+// are widened with arith.extf / arith.extsi). A second, all-parallel generic
+// then adds the original output and truncates back to the original type.
+// For i8 the truncated result equals the original wrap-around accumulation;
+// for bf16 it is generally more accurate than, and so not bit-identical to,
+// accumulating in bf16.
+//
+// Example:
+//   Input:
+//     linalg.generic ins(tensor<16x32xbf16>, tensor<32x48xbf16>)
+//                    outs(tensor<16x48xbf16>) {
+//       arith.mulf : bf16
+//       arith.addf : bf16
+//     } -> tensor<16x48xbf16>
+//
+//   Output:
+//     linalg.fill ins(f32) outs(tensor<16x48xf32>) -> tensor<16x48xf32>
+//     linalg.generic ins(tensor<16x32xbf16>, tensor<32x48xbf16>)
+//                    outs(tensor<16x48xf32>) {
+//       %a = arith.extf %in : bf16 to f32
+//       %b = arith.extf %in_2 : bf16 to f32
+//       %c = arith.mulf %a, %b : f32
+//       arith.addf %out, %c : f32
+//     } -> tensor<16x48xf32>
+//
+//     linalg.generic ins(tensor<16x48xf32>, tensor<16x48xbf16>)
+//                    outs(tensor<16x48xbf16>) {
+//       %a = arith.extf %in_2 : bf16 to f32
+//       %b = arith.addf %in, %a : f32
+//       %c = arith.truncf %b : f32 to bf16
+//     } -> tensor<16x48xbf16>
+//
+struct ConvertLinalgGenericTo32BitAccumulation
+    : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(genericOp,
+                                         "Support only for tensor type.");
+
+    if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+      return rewriter.notifyMatchFailure(
+          genericOp, "Expected exactly two inputs and one init.");
+
+    auto outType =
+        dyn_cast<RankedTensorType>(genericOp.getResult(0).getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(genericOp, "No output type detected.");
+
+    // Original (narrow) element type and the 32-bit accumulator type.
+    Type narrowType = outType.getElementType();
+    bool isFloat = narrowType.isBF16();
+    if (!isFloat && !narrowType.isSignlessInteger(8))
+      return rewriter.notifyMatchFailure(
+          genericOp, "The outs type should be BF16 or Int8.");
+    Type accType =
+        isFloat ? Type(rewriter.getF32Type()) : Type(rewriter.getI32Type());
+
+    // The widening casts built below are only valid from the narrow type
+    // (e.g. arith.extf f32 -> f32 would not verify).
+    for (Value input : genericOp.getDpsInputs()) {
+      if (getElementTypeOrSelf(input.getType()) != narrowType)
+        return rewriter.notifyMatchFailure(
+            genericOp, "Inputs must have the same element type as the outs.");
+    }
+
+    Block &body = genericOp.getRegion().front();
+    bool isMulAdd =
+        isFloat ? isMulAddAccumulation<arith::MulFOp, arith::AddFOp>(body)
+                : isMulAddAccumulation<arith::MulIOp, arith::AddIOp>(body);
+    if (!isMulAdd)
+      return rewriter.notifyMatchFailure(
+          genericOp, "Expected the region to compute out + in0 * in1.");
+
+    // Builders selecting the float or integer flavour of each arith op.
+    auto widen = [&](OpBuilder &b, Location l, Value v) -> Value {
+      if (isFloat)
+        return arith::ExtFOp::create(b, l, accType, v);
+      return arith::ExtSIOp::create(b, l, accType, v);
+    };
+    auto narrow = [&](OpBuilder &b, Location l, Value v) -> Value {
+      if (isFloat)
+        return arith::TruncFOp::create(b, l, narrowType, v);
+      return arith::TruncIOp::create(b, l, narrowType, v);
+    };
+    auto addValues = [&](OpBuilder &b, Location l, Value lhs,
+                         Value rhs) -> Value {
+      if (isFloat)
+        return arith::AddFOp::create(b, l, lhs, rhs);
+      return arith::AddIOp::create(b, l, lhs, rhs);
+    };
+    auto mulValues = [&](OpBuilder &b, Location l, Value lhs,
+                         Value rhs) -> Value {
+      if (isFloat)
+        return arith::MulFOp::create(b, l, lhs, rhs);
+      return arith::MulIOp::create(b, l, lhs, rhs);
+    };
+
+    Location loc = genericOp.getLoc();
+    Value oldOut = genericOp.getDpsInitOperand(0)->get();
+    // Take sizes from the original init so dynamic dimensions are supported.
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, oldOut);
+    auto accTensorType = RankedTensorType::get(outType.getShape(), accType);
+
+    // Zero-initialized 32-bit accumulator (a single constant of the right
+    // type, so no dead constant is left behind).
+    Value accEmpty = tensor::EmptyOp::create(rewriter, loc, sizes, accType);
+    TypedAttr zeroAttr = isFloat ? TypedAttr(rewriter.getF32FloatAttr(0.0f))
+                                 : TypedAttr(rewriter.getI32IntegerAttr(0));
+    Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
+    Value fill = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+                                        ValueRange{accEmpty})
+                     .getResult(0);
+
+    // Same iteration space as the original op, accumulating in 32 bits.
+    auto accGeneric = linalg::GenericOp::create(
+        rewriter, loc, accTensorType, genericOp.getDpsInputs(), fill,
+        genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+        [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+          // args: narrow lhs, narrow rhs, wide accumulator. Ops are created
+          // in separate statements to keep the emitted IR order deterministic.
+          Value lhs = widen(b, nestedLoc, args[0]);
+          Value rhs = widen(b, nestedLoc, args[1]);
+          Value product = mulValues(b, nestedLoc, lhs, rhs);
+          Value sum = addValues(b, nestedLoc, args[2], product);
+          linalg::YieldOp::create(b, nestedLoc, sum);
+        });
+
+    // Elementwise: add the original output in 32 bits, then truncate.
+    unsigned rank = outType.getRank();
+    SmallVector<AffineMap> identityMaps(3,
+                                        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<utils::IteratorType> parallelIters(
+        rank, utils::IteratorType::parallel);
+    Value finalInit = tensor::EmptyOp::create(rewriter, loc, sizes, narrowType);
+    auto finalGeneric = linalg::GenericOp::create(
+        rewriter, loc, outType, ValueRange{accGeneric.getResult(0), oldOut},
+        ValueRange{finalInit}, identityMaps, parallelIters,
+        [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+          // args: wide partial sum, narrow original output, narrow init.
+          Value widenedOut = widen(b, nestedLoc, args[1]);
+          Value total = addValues(b, nestedLoc, args[0], widenedOut);
+          Value result = narrow(b, nestedLoc, total);
+          linalg::YieldOp::create(b, nestedLoc, result);
+        });
+
+    rewriter.replaceOp(genericOp, finalGeneric.getResult(0));
+    return success();
+  }
+};
+
+} // namespace
+
+void x86::populateConvertLinalgGenericTo32BitAccumulationPatterns(
+    RewritePatternSet &patterns) {
+  // Register the single rewrite with the context that owns `patterns`.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ConvertLinalgGenericTo32BitAccumulation>(context);
+}
diff --git a/mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir b/mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir
new file mode 100644
index 0000000000000..2135797a6640e
--- /dev/null
+++ b/mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir
@@ -0,0 +1,290 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+!tensorA = tensor<32x32x16x2xbf16>
+!tensorB = tensor<32x16x32x2xbf16>
+!tensorC = tensor<32x32xbf16>
+
+// Batch-reduce GEMM on packed bf16 tiles inside an scf.forall.
+func.func @brgemm_bf16(%arg0: tensor<8x32x32x32xbf16>, %arg1: tensor<32x32x16x32x2xbf16>, %arg2: tensor<8x32x32x32xbf16>) -> tensor<8x32x32x32xbf16> {
+ %expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2]
+ : tensor<8x32x32x32xbf16> into tensor<8x32x32x16x2xbf16>
+
+ %0 = scf.forall (%arg3, %arg4) in (8, 32) shared_outs(%arg5 = %arg2) -> (tensor<8x32x32x32xbf16>) {
+ %extracted_slice = tensor.extract_slice %expanded[%arg3, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1]
+ : tensor<8x32x32x16x2xbf16> to !tensorA
+ %extracted_slice_0 = tensor.extract_slice %arg1[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1]
+ : tensor<32x32x16x32x2xbf16> to !tensorB
+ %extracted_slice_1 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
+ : tensor<8x32x32x32xbf16> to !tensorC
+
+ %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "reduction", "parallel", "parallel", "reduction"]}
+ ins(%extracted_slice, %extracted_slice_0 : !tensorA, !tensorB) outs(%extracted_slice_1 : !tensorC) {
+ ^bb0(%in: bf16, %in_2: bf16, %out: bf16):
+ %2 = arith.mulf %in, %in_2 : bf16
+ %3 = arith.addf %out, %2 : bf16
+ linalg.yield %3 : bf16
+ } -> !tensorC
+
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %1 into %arg5[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
+ : !tensorC into tensor<8x32x32x32xbf16>
+ }
+ }
+ return %0 : tensor<8x32x32x32xbf16>
+}
+
+// CHECK-LABEL: @brgemm_bf16
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x32xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}} : f32) outs(%[[EMPTY]] : tensor<32x32xf32>) -> tensor<32x32xf32>
+// CHECK: %[[ACC:.+]] = linalg.generic {{.*}} ins({{.*}} : tensor<32x32x16x2xbf16>, tensor<32x16x32x2xbf16>) outs(%[[FILL]] : tensor<32x32xf32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<32x32x16x2xbf16>, tensor<32x16x32x2xbf16>) outs({{.*}} : tensor<32x32xbf16>) {
+// CHECK: linalg.generic {{.*}} ins(%[[ACC]], {{.*}} : tensor<32x32xf32>, tensor<32x32xbf16>) outs({{.*}} : tensor<32x32xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+// Batch matmul on packed bf16 operands (batch dimension stays parallel).
+func.func @batch_matmul_bf16(%arg0: tensor<16x24x32x2xbf16>, %arg1: tensor<16x32x128x2xbf16>, %arg2: tensor<16x24x128xbf16>) -> tensor<16x24x128xbf16> {
+ %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs(%arg2 : tensor<16x24x128xbf16>) {
+ ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
+ %1 = arith.mulf %in, %in_0 : bf16
+ %2 = arith.addf %out, %1 : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<16x24x128xbf16>
+ return %0 : tensor<16x24x128xbf16>
+}
+
+// CHECK-LABEL: @batch_matmul_bf16
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<16x24x128xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}} : f32) outs(%[[EMPTY]] : tensor<16x24x128xf32>) -> tensor<16x24x128xf32>
+// CHECK: %[[ACC:.+]] = linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs(%[[FILL]] : tensor<16x24x128xf32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs({{.*}} : tensor<16x24x128xbf16>) {
+// CHECK: linalg.generic {{.*}} ins(%[[ACC]], {{.*}} : tensor<16x24x128xf32>, tensor<16x24x128xbf16>) outs({{.*}} : tensor<16x24x128xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d5, d0, d1, d2, d3, d4) -> (d5, d0, d2, d4, d1)>
+#map1 = affine_map<(d5, d0, d1, d2, d3, d4) -> (d5, d0, d4, d3, d1)>
+#map2 = affine_map<(d5, d0, d1, d2, d3, d4) -> (d5, d0, d2, d3)>
+
+// Matmul with two leading parallel dimensions on packed bf16 operands.
+func.func @matmul_many_dim_bf16(%arg0: tensor<2x16x24x32x2xbf16>, %arg1: tensor<2x16x32x128x2xbf16>, %arg2: tensor<2x16x24x128xbf16>) -> tensor<2x16x24x128xbf16> {
+ %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x16x24x32x2xbf16>, tensor<2x16x32x128x2xbf16>) outs(%arg2 : tensor<2x16x24x128xbf16>) {
+ ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
+ %1 = arith.mulf %in, %in_0 : bf16
+ %2 = arith.addf %out, %1 : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<2x16x24x128xbf16>
+ return %0 : tensor<2x16x24x128xbf16>
+}
+
+// CHECK-LABEL: @matmul_many_dim_bf16
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x16x24x128xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}} : f32) outs(%[[EMPTY]] : tensor<2x16x24x128xf32>) -> tensor<2x16x24x128xf32>
+// CHECK: %[[ACC:.+]] = linalg.generic {{.*}} ins({{.*}} : tensor<2x16x24x32x2xbf16>, tensor<2x16x32x128x2xbf16>) outs(%[[FILL]] : tensor<2x16x24x128xf32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<2x16x24x32x2xbf16>, tensor<2x16x32x128x2xbf16>) outs({{.*}} : tensor<2x16x24x128xbf16>) {
+// CHECK: linalg.generic {{.*}} ins(%[[ACC]], {{.*}} : tensor<2x16x24x128xf32>, tensor<2x16x24x128xbf16>) outs({{.*}} : tensor<2x16x24x128xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d2, d3, d4) -> (d0, d2, d4)>
+#map1 = affine_map<(d0, d2, d3, d4) -> (d0, d4, d3)>
+#map2 = affine_map<(d0, d2, d3, d4) -> (d2, d3)>
+
+// Batch-reduce GEMM on unpacked i8 operands: accumulation must move to i32.
+func.func @brgemm_flat_int8(%arg0: tensor<16x64x256xi8>, %arg1: tensor<16x256x128xi8>, %arg2: tensor<64x128xi8>) -> tensor<64x128xi8> {
+ %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs(%arg2 : tensor<64x128xi8>) {
+ ^bb0(%in: i8, %in_0: i8, %out: i8):
+ %1 = arith.muli %in, %in_0 : i8
+ %2 = arith.addi %out, %1 : i8
+ linalg.yield %2 : i8
+ } -> tensor<64x128xi8>
+ return %0 : tensor<64x128xi8>
+}
+
+// CHECK-LABEL: @brgemm_flat_int8
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64x128xi32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}} : i32) outs(%[[EMPTY]] : tensor<64x128xi32>) -> tensor<64x128xi32>
+// CHECK: %[[ACC:.+]] = linalg.generic {{.*}} ins({{.*}} : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs(%[[FILL]] : tensor<64x128xi32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs({{.*}} : tensor<64x128xi8>) {
+// CHECK: linalg.generic {{.*}} ins(%[[ACC]], {{.*}} : tensor<64x128xi32>, tensor<64x128xi8>) outs({{.*}} : tensor<64x128xi8>) {
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.convert_linalg_...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/190779
More information about the Mlir-commits
mailing list