[Mlir-commits] [mlir] [mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect (PR #144307)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Jun 16 00:16:56 PDT 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/144307
This patch deletes `vector.matrix_multiply` and `vector.flat_transpose`,
which are thin wrappers around the corresponding LLVM intrinsics:
- `llvm.intr.matrix.multiply`
- `llvm.intr.matrix.transpose`
These Vector dialect ops did not provide additional semantics or
abstraction beyond the LLVM intrinsics. Their removal simplifies the
lowering pipeline without losing any functionality.
The lowering chains:
- `vector.contract` → `vector.matrix_multiply` → `llvm.intr.matrix.multiply`
- `vector.transpose` → `vector.flat_transpose` → `llvm.intr.matrix.transpose`
are now replaced with:
- `vector.contract` → `llvm.intr.matrix.multiply`
- `vector.transpose` → `llvm.intr.matrix.transpose`
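For readers unfamiliar with the flattened 1-D form that both the removed ops and the intrinsics operate on, here is a minimal plain C++ sketch of the arithmetic. It is not MLIR or LLVM code. It assumes the column-major embedding that the LLVM matrix intrinsics use by default, and the names `flatMatmulColMajor` and `flatTransposeColMajor` are hypothetical, chosen only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of a flattened matrix multiply (assumes column-major).
// `lhs` is read as a lhsRows x lhsCols matrix and `rhs` as a
// lhsCols x rhsCols matrix; the lhsRows x rhsCols result is returned
// flattened in the same layout.
std::vector<double> flatMatmulColMajor(const std::vector<double> &lhs,
                                       const std::vector<double> &rhs,
                                       std::size_t lhsRows, std::size_t lhsCols,
                                       std::size_t rhsCols) {
  assert(lhs.size() == lhsRows * lhsCols && "lhs size must match its shape");
  assert(rhs.size() == lhsCols * rhsCols && "rhs size must match its shape");
  std::vector<double> res(lhsRows * rhsCols, 0.0);
  for (std::size_t c = 0; c < rhsCols; ++c)
    for (std::size_t k = 0; k < lhsCols; ++k)
      for (std::size_t r = 0; r < lhsRows; ++r)
        // Element (r, c) of a column-major matrix with R rows sits at r + c * R.
        res[r + c * lhsRows] += lhs[r + k * lhsRows] * rhs[k + c * lhsCols];
  return res;
}

// Hypothetical model of a flattened transpose (assumes column-major).
// `in` is read as a rows x cols matrix; the cols x rows transpose is
// returned flattened in the same layout.
std::vector<double> flatTransposeColMajor(const std::vector<double> &in,
                                          std::size_t rows, std::size_t cols) {
  assert(in.size() == rows * cols && "input size must match its shape");
  std::vector<double> out(in.size());
  for (std::size_t c = 0; c < cols; ++c)
    for (std::size_t r = 0; r < rows; ++r)
      // Input element (r, c) becomes output element (c, r).
      out[c + r * cols] = in[r + c * rows];
  return out;
}
```

With the assumed inputs `a = (0, 1, 2, 3)` and `b = (4, 5, 6, 7)`, `flatMatmulColMajor(a, b, 2, 2, 2)` yields `(10, 19, 14, 27)`, i.e. the matrix rows `(10, 14)` and `(19, 27)`; the second row agrees with the comment in `matrix-multiply-col.mlir`.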
This was accomplished by directly replacing:
- `vector::MatrixMultiplyOp` with `LLVM::MatrixMultiplyOp`
- `vector::FlatTransposeOp` with `LLVM::MatrixTransposeOp`
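One consequence of the direct replacement shows up in the `LowerVectorContract.cpp` hunk: the removed `vector::MatmulOp` builder inferred the flattened result type, whereas the `LLVM::MatrixMultiplyOp` call site now spells it out as `VectorType::get(lhsRows * rhsColumns, ...)`. The sketch below models only that element-count arithmetic in plain C++; `FlatMatmulShape` and `flatMatmulShape` are hypothetical names, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not an MLIR API): element counts of the flattened 1-D
// operands and result of a (lhsRows x lhsCols) by (lhsCols x rhsCols) multiply.
struct FlatMatmulShape {
  std::size_t lhsElems; // lhsRows * lhsCols
  std::size_t rhsElems; // lhsCols * rhsCols
  std::size_t resElems; // lhsRows * rhsCols
};

FlatMatmulShape flatMatmulShape(std::size_t lhsRows, std::size_t lhsCols,
                                std::size_t rhsCols) {
  assert(lhsRows > 0 && lhsCols > 0 && rhsCols > 0 &&
         "matrix dimensions must be non-zero");
  return {lhsRows * lhsCols, lhsCols * rhsCols, lhsRows * rhsCols};
}
```

With the shapes from the removed op's documentation example (`lhs_rows = 4`, `lhs_columns = 16`, `rhs_columns = 3`), this gives 64, 48 and 12 elements, matching `(vector<64xf64>, vector<48xf64>) -> vector<12xf64>`.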
Note: This change introduces a build-time dependency from `Vector` to
`LLVM`: `MLIRVectorTransforms` now links against `MLIRLLVMDialect`.
Ideally, such dependencies should be confined to dialect conversion
(`ConvertVectorToLLVM`). However, moving the lowering code there would
introduce notable churn, so this patch leaves the new dependency in
place for now.
From a8df4b81ba9aaa574e6c77ffcd0d29d0784f81f7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 29 May 2025 20:35:03 +0100
Subject: [PATCH] [mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp
from Vector dialect
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../VectorToLLVM/ConvertVectorToLLVM.h | 6 -
.../mlir/Dialect/Vector/IR/VectorOps.td | 118 ------------------
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 41 ------
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 -
.../Transforms/EmulateUnsupportedFloats.cpp | 6 +-
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Vector/Transforms/LowerVectorContract.cpp | 12 +-
.../Transforms/LowerVectorTranspose.cpp | 3 +-
.../VectorToLLVM/vector-to-llvm.mlir | 80 ------------
mlir/test/Dialect/Vector/invalid.mlir | 29 -----
mlir/test/Dialect/Vector/ops.mlir | 16 ---
...tract-to-matrix-intrinsics-transforms.mlir | 4 +-
.../Vector/vector-transpose-lowering.mlir | 4 +-
.../Vector/CPU/flat-transpose-col.mlir | 8 +-
.../Vector/CPU/flat-transpose-row.mlir | 8 +-
.../Vector/CPU/matrix-multiply-col.mlir | 2 +-
.../Vector/CPU/matrix-multiply-row.mlir | 2 +-
17 files changed, 28 insertions(+), 314 deletions(-)
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index f6b09deb4e44c..cfb6cc313bc63 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -13,12 +13,6 @@
namespace mlir {
class LLVMTypeConverter;
-/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
-/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
-/// will be needed when invoking LLVM.
-void populateVectorToLLVMMatrixConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns);
-
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..9c95677ee50da 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2710,124 +2710,6 @@ def Vector_PrintOp :
}];
}
-//===----------------------------------------------------------------------===//
-// Ops used for supporting progressive lowering and conversion type changes.
-// The Ops are typically not used directly by higher level dialects, but are
-// used by intra-dialect rewriting rules to bring vector operations closer
-// to the hardware ISA.
-//===----------------------------------------------------------------------===//
-
-/// Vector dialect matrix multiplication op that operates on flattened 1-D
-/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
-/// This may seem redundant with vector.contract but it serves the purposes of
-/// more progressive lowering and localized type conversion on the path:
-/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
-def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
- PredOpTrait<"lhs operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>,
- PredOpTrait<"rhs operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 1>>]>,
- Arguments<(
- // TODO: tighten vector element types that make sense.
- ins FixedVectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
- FixedVectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
- I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
- Results<(
- outs FixedVectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
-{
- let summary = "Vector matrix multiplication op that operates on flattened 1-D"
- " MLIR vectors";
- let description = [{
- This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
- purposes of more progressive lowering and localized type conversion.
- Higher levels typically lower matrix multiplications into 'vector.contract'
- operations. Subsequent rewriting rule progressively lower these operations
- into 'vector.matrix_multiply' operations to bring the operations closer
- to the hardware ISA.
-
- The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
- and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
- <rhs_columns> and multiplies them. The result matrix is returned embedded in
- the result vector.
-
- Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
- support scalable vectors. Hence, this Op is only available for fixed-width
- vectors. Also see:
-
- http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
-
- Example:
-
- ```mlir
- %C = vector.matrix_multiply %A, %B
- { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
- (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
- ```
- }];
- let builders = [
- OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
- "unsigned":$lhsColumns, "unsigned":$rhsColumns),
- [{
- $_state.addOperands({lhs, rhs});
- $_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
- $_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
- $_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
- $_state.addTypes(VectorType::get(lhsRows * rhsColumns,
- ::llvm::cast<VectorType>(lhs.getType()).getElementType()));
- }]>,
- ];
- let assemblyFormat = "$lhs `,` $rhs attr-dict "
- "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
-}
-
-/// Vector dialect matrix transposition op that operates on flattened 1-D
-/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
-/// This may seem redundant with vector.transpose but it serves the purposes of
-/// more progressive lowering and localized type conversion on the path:
-/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
-def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
- PredOpTrait<"source operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(
- // TODO: tighten vector element types that make sense.
- ins FixedVectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
- I32Attr:$rows, I32Attr:$columns)>,
- Results<(
- outs FixedVectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
- let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
- let description = [{
- This is the counterpart of llvm.matrix.transpose in MLIR. It serves
- the purposes of more progressive lowering and localized type conversion.
- Higher levels typically lower matrix transpositions into 'vector.transpose'
- operations. Subsequent rewriting rule progressively lower these operations
- into 'vector.flat_transpose' operations to bring the operations closer
- to the hardware ISA.
-
- The `vector.flat_transpose` op treats the 1-D input `matrix` as
- a 2-D matrix with <rows> rows and <columns> columns, and returns the
- transposed matrix in flattened form in 'res'.
-
- Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
- support scalable vectors. Hence, this Op is only available for fixed-width
- vectors. Also see:
-
- http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
-
- Example:
-
- ```mlir
- %1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
- : vector<16xf32> -> vector<16xf32>
- ```
- }];
- let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
-}
-
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f725993635672..bcdeccc54cf17 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -184,41 +184,6 @@ class VectorBitCastOpConversion
}
};
-/// Conversion pattern for a vector.matrix_multiply.
-/// This is lowered directly to the proper llvm.intr.matrix.multiply.
-class VectorMatmulOpConversion
- : public ConvertOpToLLVMPattern<vector::MatmulOp> {
-public:
- using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
- matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
- adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
- matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
- return success();
- }
-};
-
-/// Conversion pattern for a vector.flat_transpose.
-/// This is lowered directly to the proper llvm.intr.matrix.transpose.
-class VectorFlatTransposeOpConversion
- : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
-public:
- using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
- transOp, typeConverter->convertType(transOp.getRes().getType()),
- adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
- return success();
- }
-};
-
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// couterparts.
@@ -2026,12 +1991,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorScalableStepOpLowering>(converter);
}
-void mlir::populateVectorToLLVMMatrixConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<VectorMatmulOpConversion>(converter);
- patterns.add<VectorFlatTransposeOpConversion>(converter);
-}
-
namespace {
struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 293e01a5bf4d4..dcc5ded02341f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -97,11 +97,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
- populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
useVectorAlignment);
- populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 62022bfb7df1e..f14264e2f55f3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -118,9 +118,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
return converter.isLegal(op);
});
// Manually mark arithmetic-performing vector instructions.
- target.addDynamicallyLegalOp<
- vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
- vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
+ target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
+ vector::MultiDimReductionOp, vector::FMAOp,
+ vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
arith::ConstantOp, vector::SplatOp>();
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..08aef70fc4d8a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRTensorDialect
MLIRTransforms
MLIRVectorDialect
+ MLIRLLVMDialect
MLIRVectorInterfaces
MLIRVectorUtils
)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index c6627b5ec0d77..0e8c60f6a9a6e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -1280,12 +1281,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
/// %mtb = maybe_transpose
/// %flattened_a = vector.shape_cast %mta
/// %flattened_b = vector.shape_cast %mtb
-/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
+/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
/// %mtd = vector.shape_cast %flattened_d
/// %d = maybe_untranspose %mtd
/// %e = add %c, %d
/// ```
-/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when vectorContractLowering is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
@@ -1362,8 +1362,12 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
- Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
- rhsColumns);
+ Value mul = rew.create<LLVM::MatrixMultiplyOp>(
+ loc,
+ VectorType::get(lhsRows * rhsColumns,
+ cast<VectorType>(lhs.getType()).getElementType()),
+ lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+
mul = rew.create<vector::ShapeCastOp>(
loc,
VectorType::get({lhsRows, rhsColumns},
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..05fb613393584 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -338,7 +339,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
- Value trans = rewriter.create<vector::FlatTransposeOp>(
+ Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
loc, flattenedType, matrix, rows, columns);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
return success();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 64e51f5554628..72810b5dddaa3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1424,36 +1424,6 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v
return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>
}
-// -----
-
-//===----------------------------------------------------------------------===//
-// vector.matrix_multiply
-//===----------------------------------------------------------------------===//
-
-// 4x16 16x3 4x3
-func.func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
- %C = vector.matrix_multiply %A, %B
- { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
- (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
- return %C: vector<12xf64>
-}
-// CHECK-LABEL: @matrix_ops
-// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
-// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
-// CHECK-SAME: } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
-
-// -----
-
-func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> {
- %C = vector.matrix_multiply %A, %B
- { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
- (vector<64xindex>, vector<48xindex>) -> vector<12xindex>
- return %C: vector<12xindex>
-}
-// CHECK-LABEL: @matrix_ops_index
-// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
-// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
-// CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64>
// -----
@@ -1602,56 +1572,6 @@ func.func @create_mask_1d_scalable(%num_elems : index) -> vector<[4]xi1> {
// -----
-//===----------------------------------------------------------------------===//
-// vector.flat_transpose
-//===----------------------------------------------------------------------===//
-
-func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
- %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
- : vector<16xf32> -> vector<16xf32>
- return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: func @flat_transpose
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
-// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME: vector<16xf32> into vector<16xf32>
-// CHECK: return %[[T]] : vector<16xf32>
-
-// -----
-
-func.func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> {
- %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
- : vector<16xindex> -> vector<16xindex>
- return %0 : vector<16xindex>
-}
-// CHECK-LABEL: func @flat_transpose_index
-// CHECK-SAME: %[[A:.*]]: vector<16xindex>
-// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<16xindex> to vector<16xi64>
-// CHECK: %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]]
-// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME: vector<16xi64> into vector<16xi64>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<16xi64> to vector<16xindex>
-// CHECK: return %[[T2]] : vector<16xindex>
-
-// -----
-
-func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
- %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
- : vector<16xf32> -> vector<16xf32>
- return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: func @flat_transpose
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
-// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME: vector<16xf32> into vector<16xf32>
-// CHECK: return %[[T]] : vector<16xf32>
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.gather
//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..a2ede475a1478 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1328,13 +1328,6 @@ func.func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
// -----
-func.func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) {
- // expected-error at +1 {{'vector.flat_transpose' op failed to verify that source operand and result have same element type}}
- %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf64>
-}
-
-// -----
-
func.func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
// expected-error at +1 {{expects operand to be a memref with identity layout}}
%0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
@@ -1937,28 +1930,6 @@ func.func @invalid_step_2d() {
// -----
-func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
- // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}}
- %c = vector.matrix_multiply %a, %b {
- lhs_rows = 2: i32,
- lhs_columns = 2: i32 ,
- rhs_columns = 2: i32 }
- : (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64>
-
- return
-}
-
-// -----
-
-func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32> {
- // expected-error @+1 {{'vector.flat_transpose' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[16]xf32>'}}
- %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
- : vector<[16]xf32> -> vector<[16]xf32>
- return %0 : vector<[16]xf32>
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..179e42156901f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -738,22 +738,6 @@ func.func @transpose_int_0d(%arg0: vector<i32>) -> vector<i32> {
return %0 : vector<i32>
}
-// CHECK-LABEL: @flat_transpose_fp
-func.func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> {
- // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
- %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf32>
- // CHECK: return %[[X]] : vector<16xf32>
- return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: @flat_transpose_int
-func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
- // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
- %0 = vector.flat_transpose %arg0 { rows = 2: i32, columns = 8: i32 } : vector<16xi32> -> vector<16xi32>
- // CHECK: return %[[X]] : vector<16xi32>
- return %0 : vector<16xi32>
-}
-
// CHECK-LABEL: @vector_load_and_store_0d_scalar_memref
func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 08ac2ac5bb7d5..6960ab33925e5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -29,7 +29,7 @@
// CHECK: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
// CHECK: %[[b6:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
// CHECK: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
-// CHECK: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
+// CHECK: %[[mm1:.*]] = llvm.intr.matrix.multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
// CHECK: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK: %[[mm3:.*]] = vector.insert %[[mm2]], %[[ub_1]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
@@ -44,7 +44,7 @@ func.func @matmul(%arg0: vector<2x4xf32>,
}
// CHECK-LABEL: func @matmul_scalable
-// CHECK-NOT: vector.matrix_multiply
+// CHECK-NOT: llvm.intr.matrix.multiply
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
%arg1: vector<4x[3]xf32>,
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index a730f217f027d..ca130538c908a 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -139,7 +139,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @transpose(
func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
// CHECK: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
- // CHECK: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32>
+ // CHECK: llvm.intr.matrix.transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
// CHECK: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
return %0 : vector<4x2xf32>
@@ -150,7 +150,7 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
// CHECK-LABEL: func @transpose_scalable(
func.func @transpose_scalable(%arg0: vector<2x[4]xf32>) -> vector<[4]x2xf32> {
// CHECK-NOT: vector.shape_cast
- // CHECK-NOT: vector.flat_transpose
+ // CHECK-NOT: llvm.intr.matrix.transpose
// CHECK: vector.transpose
%0 = vector.transpose %arg0, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
return %0 : vector<[4]x2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir
index b414242b34cc0..86bd0b1e09763 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir
@@ -57,10 +57,10 @@ func.func @entry() {
// ( 1, 4 ) -> ( 3, 4, 5 )
// ( 2, 5 )
//
- %d = vector.flat_transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
- %e = vector.flat_transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
- %f = vector.flat_transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> -> vector<6xf64>
- %g = vector.flat_transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> -> vector<6xf64>
+ %d = llvm.intr.matrix.transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+ %e = llvm.intr.matrix.transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+ %f = llvm.intr.matrix.transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> into vector<6xf64>
+ %g = llvm.intr.matrix.transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> into vector<6xf64>
vector.print %d : vector<4xf64>
vector.print %e : vector<4xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir
index 95b178e04a2bb..55103bc686fb2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir
@@ -57,10 +57,10 @@ func.func @entry() {
// ( 2, 3 ) -> ( 1, 3, 5 )
// ( 4, 5 )
//
- %d = vector.flat_transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
- %e = vector.flat_transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64>
- %f = vector.flat_transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> -> vector<6xf64>
- %g = vector.flat_transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> -> vector<6xf64>
+ %d = llvm.intr.matrix.transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+ %e = llvm.intr.matrix.transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64>
+ %f = llvm.intr.matrix.transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> into vector<6xf64>
+ %g = llvm.intr.matrix.transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> into vector<6xf64>
vector.print %d : vector<4xf64>
vector.print %e : vector<4xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir
index 8f75ec98465ca..09941192cbc42 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir
@@ -39,7 +39,7 @@ func.func @entry() {
// x = |/ | column-major!
// ( 1, 3 ) (5, 7) ( 19, 27 )
//
- %c = vector.matrix_multiply %a, %b
+ %c = llvm.intr.matrix.multiply %a, %b
{ lhs_rows = 2: i32, lhs_columns = 2: i32 , rhs_columns = 2: i32 }
: (vector<4xf64>, vector<4xf64>) -> vector<4xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir
index b7d27c45226ef..d5f511c8ac119 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir
@@ -39,7 +39,7 @@ func.func @entry() {
// x =
// ( 2, 3 ) (6, 7) ( 26, 31 )
//
- %c = vector.matrix_multiply %a, %b
+ %c = llvm.intr.matrix.multiply %a, %b
{ lhs_rows = 2: i32, lhs_columns = 2: i32 , rhs_columns = 2: i32 }
: (vector<4xf64>, vector<4xf64>) -> vector<4xf64>