[Mlir-commits] [mlir] [mlir][x86]Convert a linalg.generic with bf16/i8 accumulation to f32/i32 accumulation (PR #190779)

Arun Thangamani llvmlistbot at llvm.org
Tue Apr 7 04:45:53 PDT 2026


https://github.com/arun-thmn created https://github.com/llvm/llvm-project/pull/190779

Rewrites a `linalg.generic` from low-precision (bf16/i8) to high-precision accumulation (f32/i32). Performs compute (mul + add) in higher precision, starting from a zero-initialized accumulator. Then adds the original output and casts (truncates) back to the original type.

>From 3a717c6176d49319eecc429a4d14cf46af789640 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 7 Apr 2026 04:42:24 -0700
Subject: [PATCH] convert a linalg.generic with bf16/i8 acc to f32/i32 acc

---
 .../X86/TransformOps/X86TransformOps.td       |  11 +
 mlir/include/mlir/Dialect/X86/Transforms.h    |   8 +
 .../X86/TransformOps/X86TransformOps.cpp      |   5 +
 .../lib/Dialect/X86/Transforms/CMakeLists.txt |   1 +
 ...onvertLinalgGenericTo32BitAccumulation.cpp | 214 +++++++++++++
 .../X86/linalg-generic-to-32bit-acc.mlir      | 290 ++++++++++++++++++
 6 files changed, 529 insertions(+)
 create mode 100644 mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp
 create mode 100644 mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir

diff --git a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
index c474cfb47d003..15aba2a21c4a2 100644
--- a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
+++ b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
@@ -82,5 +82,16 @@ def ApplyVectorContractToAMXDotProductPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyConvertLinalgGenericTo32BitAccumulationPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns to convert a linalg.generic from low-precision (bf16/i8) to
+    high-precision accumulation (f32/i32) and finally truncate the output back to the original type.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86/Transforms.h b/mlir/include/mlir/Dialect/X86/Transforms.h
index 6ebba5e94ec7c..749e82a4b4e36 100644
--- a/mlir/include/mlir/Dialect/X86/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86/Transforms.h
@@ -110,6 +110,14 @@ void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
 // Int8).
 void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns);
 
+// Rewrites a linalg.generic from low-precision (bf16/i8) to high-precision
+// accumulation (f32/i32).
+// Performs compute (mul + add) in higher precision, starting from a
+// zero-initialized accumulator. Then adds the original output and casts
+// (truncates) back to the original type.
+void populateConvertLinalgGenericTo32BitAccumulationPatterns(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
index 390b21e12b0ed..fbddf19be8848 100644
--- a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
+++ b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
@@ -52,6 +52,11 @@ void mlir::transform::ApplyVectorContractToAMXDotProductPatternsOp::
   x86::populateVectorContractToAMXDotProductPatterns(patterns);
 }
 
+void mlir::transform::ApplyConvertLinalgGenericTo32BitAccumulationPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86::populateConvertLinalgGenericTo32BitAccumulationPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
index 9c3695536cda9..a4bcb17f9fdb0 100644
--- a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRX86Transforms
   SinkVectorProducerOps.cpp
   ShuffleVectorFMAOps.cpp
   VectorContractToAMXDotProduct.cpp
+  ConvertLinalgGenericTo32BitAccumulation.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp b/mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp
new file mode 100644
index 0000000000000..3f83006debbe5
--- /dev/null
+++ b/mlir/lib/Dialect/X86/Transforms/ConvertLinalgGenericTo32BitAccumulation.cpp
@@ -0,0 +1,214 @@
+//===- ConvertLinalgGenericTo32BitAccumulation.cpp------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/X86/Transforms.h"
+#include "mlir/Dialect/X86/X86Dialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+// Rewrites a linalg.generic from low-precision (bf16/i8) to high-precision
+// accumulation (f32/i32).
+// Performs compute (mul + add) in higher precision, starting from a
+// zero-initialized accumulator. Then adds the original output and casts
+// (truncates) back to the original type.
+//
+// Example:
+// Input:
+// linalg.generic ins(tensor<16x32xbf16>, tensor<32x48xbf16>)
+// outs(tensor<16x48xbf16>) { arith.mulf : bf16; arith.addf : bf16 } ->
+// tensor<16x48xbf16>
+//
+// Output:
+// linalg.fill ins(f32) outs(tensor<16x48xf32>) -> tensor<16x48xf32>
+// linalg.generic ins(tensor<16x32xbf16>, tensor<32x48xbf16>)
+// outs(tensor<16x48xf32>) { %a = arith.extf %in : bf16 to f32; %b = arith.extf
+// %in_2 : bf16 to f32; %c = arith.mulf %a, %b : f32; arith.addf %out, %c : f32 }
+// -> tensor<16x48xf32>
+//
+// linalg.generic ins(tensor<16x48xf32>, tensor<16x48xbf16>)
+// outs(tensor<16x48xbf16>) { %a = arith.extf %in_2 : bf16 to f32; %b =
+// arith.addf %in, %a : f32; %c = arith.truncf %b : f32 to bf16 } ->
+// tensor<16x48xbf16>
+//
+struct ConvertLinalgGenericTo32BitAccumulation
+    : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!genericOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(genericOp,
+                                         "Support only for tensor type.");
+
+    if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+      return rewriter.notifyMatchFailure(genericOp,
+                                         "Needed two input tensors.");
+
+    auto outType =
+        llvm::dyn_cast<RankedTensorType>(genericOp.getResult(0).getType());
+
+    if (!outType)
+      return rewriter.notifyMatchFailure(genericOp, "No output type detected.");
+
+    if (!outType.getElementType().isBF16() &&
+        !outType.getElementType().isSignlessInteger(8))
+      return rewriter.notifyMatchFailure(
+          genericOp, "The outs type should be BF16 or Int8.");
+
+    Type ipType = rewriter.getBF16Type();
+    Type opType = rewriter.getF32Type();
+
+    if (outType.getElementType().isSignlessInteger(8)) {
+      ipType = rewriter.getIntegerType(8);
+      opType = rewriter.getIntegerType(32);
+    }
+
+    if (outType.getElementType().isBF16()) {
+      for (Operation &innerOp : genericOp.getRegion().front()) {
+        if (isa<arith::MulFOp, arith::AddFOp, linalg::YieldOp>(innerOp))
+          continue;
+
+        return rewriter.notifyMatchFailure(
+            genericOp,
+            "Upsupported operations inside linalg.generic's region.");
+      }
+    }
+
+    if (outType.getElementType().isSignlessInteger(8)) {
+      for (Operation &innerOp : genericOp.getRegion().front()) {
+        if (isa<arith::MulIOp, arith::AddIOp, linalg::YieldOp>(innerOp))
+          continue;
+
+        return rewriter.notifyMatchFailure(
+            genericOp,
+            "Upsupported operations inside linalg.generic's region.");
+      }
+    }
+
+    auto loc = genericOp.getLoc();
+    auto tensorType = RankedTensorType::get(outType.getShape(), opType);
+
+    // tensor.empty
+    auto empty =
+        tensor::EmptyOp::create(rewriter, loc, outType.getShape(), opType);
+
+    auto zeroAttr = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0);
+    auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
+    if (outType.getElementType().isSignlessInteger(8)) {
+      auto zeroAttrI32 =
+          rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0);
+      zero = arith::ConstantOp::create(rewriter, loc, zeroAttrI32);
+    }
+
+    // fill
+    auto fill = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+                                       ValueRange{empty})
+                    .getResult(0);
+
+    // ---- 3. Build new linalg.generic (32 accumulation) ----
+    auto newGeneric = linalg::GenericOp::create(
+        rewriter, loc,
+        tensorType,               // result type
+        genericOp.getDpsInputs(), // same inputs
+        fill,                     // new init
+        genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          // args: bf16/I8, bf16/I8, f32/I32
+          auto a = args[0];
+          auto bval = args[1];
+          auto acc = args[2];
+
+          Value sum;
+          if (outType.getElementType().isBF16()) {
+            // cast inputs
+            auto a32 = arith::ExtFOp::create(b, loc, opType, a);
+            auto b32 = arith::ExtFOp::create(b, loc, opType, bval);
+
+            // mul + add
+            auto mul = arith::MulFOp::create(b, loc, a32, b32);
+            sum = arith::AddFOp::create(b, loc, acc, mul);
+          }
+
+          if (outType.getElementType().isSignlessInteger(8)) {
+            // cast inputs
+            auto a32 = arith::ExtSIOp::create(b, loc, opType, a);
+            auto b32 = arith::ExtSIOp::create(b, loc, opType, bval);
+
+            // mul + add
+            auto mul = arith::MulIOp::create(b, loc, a32, b32);
+            sum = arith::AddIOp::create(b, loc, acc, mul);
+          }
+
+          linalg::YieldOp::create(b, loc, sum);
+        });
+
+    auto outDimSize = outType.getShape().size();
+
+    llvm::SmallVector<utils::IteratorType> iters(outDimSize,
+                                                 utils::IteratorType::parallel);
+
+    llvm::ArrayRef<utils::IteratorType> iterRef = iters;
+
+    // ---- 4. Add original output + truncate ----
+    auto oldOut = genericOp.getDpsInitOperand(0)->get();
+    auto resultType = outType;
+
+    auto finalGeneric = linalg::GenericOp::create(
+        rewriter, loc, resultType, ValueRange{newGeneric.getResult(0), oldOut},
+        ValueRange{tensor::EmptyOp::create(rewriter, loc, outType.getShape(),
+                                           outType.getElementType())},
+        llvm::ArrayRef<AffineMap>{rewriter.getMultiDimIdentityMap(outDimSize),
+                                  rewriter.getMultiDimIdentityMap(outDimSize),
+                                  rewriter.getMultiDimIdentityMap(outDimSize)},
+        iterRef, [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto acc = args[0];
+          auto accActual = args[1];
+
+          Value cast;
+          if (outType.getElementType().isBF16()) {
+            auto accActualF32 =
+                arith::ExtFOp::create(b, loc, opType, accActual);
+
+            auto sum = arith::AddFOp::create(b, loc, acc, accActualF32);
+
+            cast = arith::TruncFOp::create(b, loc, ipType, sum);
+          }
+
+          if (outType.getElementType().isSignlessInteger(8)) {
+            auto accActualI32 =
+                arith::ExtSIOp::create(b, loc, opType, accActual);
+
+            auto sum = arith::AddIOp::create(b, loc, acc, accActualI32);
+
+            cast = arith::TruncIOp::create(b, loc, ipType, sum);
+          }
+
+          linalg::YieldOp::create(b, loc, cast);
+        });
+
+    // ---- 5. Replace ----
+    rewriter.replaceOp(genericOp, finalGeneric.getResult(0));
+
+    return success();
+  }
+};
+
+} // namespace
+
+void x86::populateConvertLinalgGenericTo32BitAccumulationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ConvertLinalgGenericTo32BitAccumulation>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir b/mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir
new file mode 100644
index 0000000000000..2135797a6640e
--- /dev/null
+++ b/mlir/test/Dialect/X86/linalg-generic-to-32bit-acc.mlir
@@ -0,0 +1,290 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+!tensorA = tensor<32x32x16x2xbf16>
+!tensorB = tensor<32x16x32x2xbf16>
+!tensorC = tensor<32x32xbf16>
+
+func.func @brgemm_bf16(%arg0: tensor<8x32x32x32xbf16>, %arg1: tensor<32x32x16x32x2xbf16>, %arg2: tensor<8x32x32x32xbf16>) -> tensor<8x32x32x32xbf16> {
+  %expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] 
+		: tensor<8x32x32x32xbf16> into tensor<8x32x32x16x2xbf16>
+
+  %0 = scf.forall (%arg3, %arg4) in (8, 32) shared_outs(%arg5 = %arg2) -> (tensor<8x32x32x32xbf16>) {
+    %extracted_slice = tensor.extract_slice %expanded[%arg3, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] 
+		: tensor<8x32x32x16x2xbf16> to !tensorA
+    %extracted_slice_0 = tensor.extract_slice %arg1[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] 
+		: tensor<32x32x16x32x2xbf16> to !tensorB
+    %extracted_slice_1 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] 
+		: tensor<8x32x32x32xbf16> to !tensorC
+
+    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = 
+		["reduction", "reduction", "parallel", "parallel", "reduction"]} 
+		ins(%extracted_slice, %extracted_slice_0 : !tensorA, !tensorB) outs(%extracted_slice_1 : !tensorC) {
+    ^bb0(%in: bf16, %in_2: bf16, %out: bf16):
+      %2 = arith.mulf %in, %in_2 : bf16
+      %3 = arith.addf %out, %2 : bf16
+      linalg.yield %3 : bf16
+    } -> !tensorC
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %1 into %arg5[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] 
+	: !tensorC into tensor<8x32x32x32xbf16>
+    }
+  }
+  return %0 : tensor<8x32x32x32xbf16>
+}
+
+// CHECK-LABEL: @brgemm_bf16
+// CHECK: tensor.empty() : tensor<32x32xf32>
+// CHECK: linalg.fill ins(%cst : f32) outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32> 
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<32x32x16x2xbf16>, tensor<32x16x32x2xbf16>) outs({{.*}} : tensor<32x32xf32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<32x32x16x2xbf16>, tensor<32x16x32x2xbf16>) outs({{.*}} : tensor<32x32xbf16>) {
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<32x32xf32>, tensor<32x32xbf16>) outs({{.*}} : tensor<32x32xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+func.func @batch_matmul_bf16(%arg0: tensor<16x24x32x2xbf16>, %arg1: tensor<16x32x128x2xbf16>, %arg2: tensor<16x24x128xbf16>) -> tensor<16x24x128xbf16> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs(%arg2 : tensor<16x24x128xbf16>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
+    %1 = arith.mulf %in, %in_0 : bf16
+    %2 = arith.addf %out, %1 : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<16x24x128xbf16>
+  return %0 : tensor<16x24x128xbf16>
+}
+
+// CHECK-LABEL: @batch_matmul_bf16
+// CHECK: tensor.empty() : tensor<16x24x128xf32>
+// CHECK: linalg.fill ins(%cst : f32) outs(%0 : tensor<16x24x128xf32>) -> tensor<16x24x128xf32>
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs({{.*}} : tensor<16x24x128xf32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs({{.*}} : tensor<16x24x128xbf16>) {
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x128xf32>, tensor<16x24x128xbf16>) outs({{.*}} : tensor<16x24x128xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d5, d0, d1, d2, d3, d4) -> (d5, d0, d2, d4, d1)>
+#map1 = affine_map<(d5, d0, d1, d2, d3, d4) -> (d5, d0, d4, d3, d1)>
+#map2 = affine_map<(d5, d0, d1, d2, d3, d4) -> (d5, d0, d2, d3)>
+
+func.func @matmul_many_dim_bf16(%arg0: tensor<2x16x24x32x2xbf16>, %arg1: tensor<2x16x32x128x2xbf16>, %arg2: tensor<2x16x24x128xbf16>) -> tensor<2x16x24x128xbf16> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x16x24x32x2xbf16>, tensor<2x16x32x128x2xbf16>) outs(%arg2 : tensor<2x16x24x128xbf16>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
+    %1 = arith.mulf %in, %in_0 : bf16
+    %2 = arith.addf %out, %1 : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<2x16x24x128xbf16>
+  return %0 : tensor<2x16x24x128xbf16>
+}
+
+// CHECK-LABEL: @matmul_many_dim_bf16
+// CHECK: tensor.empty() : tensor<2x16x24x128xf32>
+// CHECK: linalg.fill ins(%cst : f32) outs(%0 : tensor<2x16x24x128xf32>) -> tensor<2x16x24x128xf32>
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<2x16x24x32x2xbf16>, tensor<2x16x32x128x2xbf16>) outs({{.*}} : tensor<2x16x24x128xf32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<2x16x24x32x2xbf16>, tensor<2x16x32x128x2xbf16>) outs({{.*}} : tensor<2x16x24x128xbf16>) {
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<2x16x24x128xf32>, tensor<2x16x24x128xbf16>) outs({{.*}} : tensor<2x16x24x128xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0,  d2, d3, d4) -> (d0, d2, d4)>
+#map1 = affine_map<(d0,  d2, d3, d4) -> (d0, d4, d3)>
+#map2 = affine_map<(d0,  d2, d3, d4) -> (d2, d3)>
+
+func.func @brgemm_flat_int8(%arg0: tensor<16x64x256xi8>, %arg1: tensor<16x256x128xi8>, %arg2: tensor<64x128xi8>) -> tensor<64x128xi8> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs(%arg2 : tensor<64x128xi8>) {
+  ^bb0(%in: i8, %in_0: i8, %out: i8):
+    %1 = arith.muli %in, %in_0 : i8
+    %2 = arith.addi %out, %1 : i8
+    linalg.yield %2 : i8
+  } -> tensor<64x128xi8>
+  return %0 : tensor<64x128xi8>
+}
+
+// CHECK-LABEL: @brgemm_flat_int8
+// CHECK: tensor.empty() : tensor<64x128xi32>
+// CHECK: linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<64x128xi32>) -> tensor<64x128xi32>
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs({{.*}} : tensor<64x128xi32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs({{.*}} : tensor<64x128xi8>) {
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<64x128xi32>, tensor<64x128xi8>) outs({{.*}} : tensor<64x128xi8>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+func.func @negative_sub_op_in_generic(%arg0: tensor<16x24x32x2xbf16>, %arg1: tensor<16x32x128x2xbf16>, %arg2: tensor<16x24x128xbf16>) -> tensor<16x24x128xbf16> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs(%arg2 : tensor<16x24x128xbf16>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
+    %1 = arith.mulf %in, %in_0 : bf16
+    %2 = arith.subf %out, %1 : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<16x24x128xbf16>
+  return %0 : tensor<16x24x128xbf16>
+}
+
+// CHECK-LABEL: @negative_sub_op_in_generic
+// CHECK-NOT: tensor.empty() : tensor<16x24x128xf32>
+// CHECK-NOT: linalg.fill ins(%cst : f32) outs(%0 : tensor<16x24x128xf32>) -> tensor<16x24x128xf32>
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs({{.*}} : tensor<16x24x128xf32>) {
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs({{.*}} : tensor<16x24x128xbf16>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x128xf32>, tensor<16x24x128xbf16>) outs({{.*}} : tensor<16x24x128xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+func.func @negative_f16_type(%arg0: tensor<16x24x32x2xf16>, %arg1: tensor<16x32x128x2xf16>, %arg2: tensor<16x24x128xf16>) -> tensor<16x24x128xf16> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x24x32x2xf16>, tensor<16x32x128x2xf16>) outs(%arg2 : tensor<16x24x128xf16>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f16):
+    %1 = arith.mulf %in, %in_0 : f16
+    %2 = arith.addf %out, %1 : f16
+    linalg.yield %2 : f16
+  } -> tensor<16x24x128xf16>
+  return %0 : tensor<16x24x128xf16>
+}
+
+// CHECK-LABEL: @negative_f16_type
+// CHECK-NOT: tensor.empty() : tensor<16x24x128xf32>
+// CHECK-NOT: linalg.fill ins(%cst : f32) outs(%0 : tensor<16x24x128xf32>) -> tensor<16x24x128xf32>
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xf16>, tensor<16x32x128x2xf16>) outs({{.*}} : tensor<16x24x128xf32>) {
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xf16>, tensor<16x32x128x2xf16>) outs({{.*}} : tensor<16x24x128xf16>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x128xf32>, tensor<16x24x128xf16>) outs({{.*}} : tensor<16x24x128xf16>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d2, d3, d4) -> (d0, d2, d4)>
+#map1 = affine_map<(d0, d2, d3, d4) -> (d0, d4, d3)>
+#map2 = affine_map<(d0, d2, d3, d4) -> (d2, d3)>
+
+func.func @negative_i32_acc(%arg0: tensor<16x64x256xi8>, %arg1: tensor<16x256x128xi8>, %arg2: tensor<64x128xi32>) -> tensor<64x128xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs(%arg2 : tensor<64x128xi32>) {
+  ^bb0(%in: i8, %in_0: i8, %out: i32):
+    %a = arith.extsi %in : i8 to i32
+    %b = arith.extsi %in_0 : i8 to i32
+    %1 = arith.muli %a, %b : i32
+    %2 = arith.addi %out, %1 : i32
+    linalg.yield %2 : i32
+  } -> tensor<64x128xi32>
+  return %0 : tensor<64x128xi32>
+}
+
+// CHECK-LABEL: @negative_i32_acc
+// CHECK-NOT: tensor.empty() : tensor<64x128xi32>
+// CHECK-NOT: linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<64x128xi32>) -> tensor<64x128xi32>
+// CHECK: linalg.generic {{.*}} ins({{.*}} : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs({{.*}} : tensor<64x128xi32>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x64x256xi8>, tensor<16x256x128xi8>) outs({{.*}} : tensor<64x128xi8>) {
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<64x128xi32>, tensor<64x128xi8>) outs({{.*}} : tensor<64x128xi8>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @negative_memref(%arg0: memref<16x24x32x2xbf16>, %arg1: memref<16x32x128x2xbf16>, %arg2: memref<24x128xbf16>) -> memref<24x128xbf16> {
+  linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<16x24x32x2xbf16>, memref<16x32x128x2xbf16>) outs(%arg2 : memref<24x128xbf16>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
+    %0 = arith.mulf %in, %in_0 : bf16
+    %1 = arith.addf %out, %0 : bf16
+    linalg.yield %1 : bf16
+  }
+  %alloc = memref.alloc() : memref<24x128xbf16>
+  memref.copy %arg2, %alloc : memref<24x128xbf16> to memref<24x128xbf16>
+  return %alloc : memref<24x128xbf16>
+}
+
+// CHECK-LABEL: @negative_memref
+// CHECK-NOT: tensor.empty() : tensor<24x128xf32>
+// CHECK-NOT: linalg.fill ins(%cst : f32) outs(%0 : tensor<24x128xf32>) -> tensor<24x128xf32>
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<16x24x32x2xbf16>, tensor<16x32x128x2xbf16>) outs({{.*}} : tensor<24x128xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<16x24x32x2xbf16>, memref<16x32x128x2xbf16>) outs(%arg2 : memref<24x128xbf16>) { 
+// CHECK-NOT: linalg.generic {{.*}} ins({{.*}} : tensor<24x128xf32>, tensor<24x128xbf16>) outs({{.*}} : tensor<24x128xbf16>) {
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.convert_linalg_generic_to_32_bit_accumulation
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list